try to get biases working somewhat naturally
This commit is contained in:
parent
37003323ff
commit
a794f068a5
@ -19,7 +19,7 @@ node(x::mx.SymbolicNode) = x
|
||||
|
||||
graph(::typeof(tuple), args...) = (args...,)
|
||||
graph(::typeof(+), args...) = mx.broadcast_plus(args...)
|
||||
graph(::typeof(*), x, W) = mx.dot(transpose(W), x) # Adjustments for batching
|
||||
graph(::typeof(*), x, W) = mx.dot(W, x) # Adjustments for batching
|
||||
graph(::typeof(σ), x) = mx.Activation(data = x, act_type = :sigmoid)
|
||||
graph(::typeof(relu), x) = mx.Activation(data = x, act_type = :relu)
|
||||
graph(::typeof(tanh), x) = mx.Activation(data = x, act_type = :tanh)
|
||||
@ -38,7 +38,7 @@ graph(ctx::Context, d::Affine, x) =
|
||||
register(ctx,
|
||||
mx.FullyConnected(data = x,
|
||||
num_hidden = size(d.W.x, 2),
|
||||
weight = var(ctx, d.W),
|
||||
weight = var(ctx, d.W, size(d.W)),
|
||||
bias = var(ctx, d.b, size(d.b, 2))))
|
||||
|
||||
# TODO: use actual params}
|
||||
@ -63,7 +63,7 @@ register(ctx::Context, node) = node
|
||||
|
||||
function var(ctx::Context, p::Flux.Param, size = nothing)
|
||||
id = gensym()
|
||||
ctx[:params][id] = size == nothing ? p.x : reshape(p.x, size...)
|
||||
ctx[:params][id] = size == nothing ? rebatch_last(p.x) : reshape(p.x, size...)
|
||||
return mx.Variable(id)
|
||||
end
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user