try to get biases working somewhat naturally

Mike J Innes 2017-02-21 16:07:58 +00:00
parent 37003323ff
commit a794f068a5


@@ -19,7 +19,7 @@ node(x::mx.SymbolicNode) = x
 graph(::typeof(tuple), args...) = (args...,)
 graph(::typeof(+), args...) = mx.broadcast_plus(args...)
-graph(::typeof(*), x, W) = mx.dot(transpose(W), x) # Adjustments for batching
+graph(::typeof(*), x, W) = mx.dot(W, x) # Adjustments for batching
 graph(::typeof(σ), x) = mx.Activation(data = x, act_type = :sigmoid)
 graph(::typeof(relu), x) = mx.Activation(data = x, act_type = :relu)
 graph(::typeof(tanh), x) = mx.Activation(data = x, act_type = :tanh)
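
The changed `*` rule drops the transpose when lowering a Flux matrix product to an MXNet dot. A rough, purely illustrative sketch of the shape bookkeeping, assuming a (batch, in) layout on the Flux side and a weight already registered in the orientation mx.dot expects (neither assumption is stated in this diff):

    # Illustrative only; the layouts here are assumptions.
    # Flux expression:  y = x * W              # x assumed (batch, in), W assumed (in, out)
    # old lowering:     mx.dot(transpose(W), x)
    # new lowering:     mx.dot(W, x)           # no transpose at graph-build time; relies on
    #                                          # the registered parameter already having the
    #                                          # orientation mx.dot expects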
@@ -38,7 +38,7 @@ graph(ctx::Context, d::Affine, x) =
   register(ctx,
     mx.FullyConnected(data = x,
                       num_hidden = size(d.W.x, 2),
-                      weight = var(ctx, d.W),
+                      weight = var(ctx, d.W, size(d.W)),
                       bias = var(ctx, d.b, size(d.b, 2))))
 # TODO: use actual params
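
With the explicit sizes, the weight is registered at its full size while the bias is collapsed to a plain vector before reaching FullyConnected. A hedged reading of the shapes involved (the concrete Affine field layouts are assumptions for illustration, not taken from this diff):

    # Assumed shapes, for illustration only:
    # d.W ~ (in, out):  var(ctx, d.W, size(d.W)) reshapes to the same (in, out), which
    #                   skips the rebatch_last fallback in var below
    # d.b ~ (1, out):   var(ctx, d.b, size(d.b, 2)) stores reshape(d.b.x, out), i.e. the
    #                   length-out bias vector FullyConnected expects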
@@ -63,7 +63,7 @@ register(ctx::Context, node) = node
 function var(ctx::Context, p::Flux.Param, size = nothing)
   id = gensym()
-  ctx[:params][id] = size == nothing ? p.x : reshape(p.x, size...)
+  ctx[:params][id] = size == nothing ? rebatch_last(p.x) : reshape(p.x, size...)
   return mx.Variable(id)
 end
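
For reference, a minimal usage sketch of var as the Affine lowering above calls it; the parameter shapes and the binding step described here are assumptions for illustration, not part of this change:

    # Hypothetical illustration; assume p and q are Flux.Params wrapping a 1×5 bias
    # and a 10×5 weight respectively.
    v = var(ctx, p, size(p.x, 2))   # stores reshape(p.x, 5) in ctx[:params] under a fresh gensym
    w = var(ctx, q)                 # no size given: stores rebatch_last(q.x) instead
    # Both return mx.Variable placeholders; the arrays kept in ctx[:params] are presumably
    # what gets bound to those variables when the MXNet executor is set up.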