From a794f068a531889745ad097097e4b8184dabb618 Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Tue, 21 Feb 2017 16:07:58 +0000 Subject: [PATCH] try to get biases working somewhat naturally --- src/backend/mxnet/graph.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/backend/mxnet/graph.jl b/src/backend/mxnet/graph.jl index ad238815..8e56b2c7 100644 --- a/src/backend/mxnet/graph.jl +++ b/src/backend/mxnet/graph.jl @@ -19,7 +19,7 @@ node(x::mx.SymbolicNode) = x graph(::typeof(tuple), args...) = (args...,) graph(::typeof(+), args...) = mx.broadcast_plus(args...) -graph(::typeof(*), x, W) = mx.dot(transpose(W), x) # Adjustments for batching +graph(::typeof(*), x, W) = mx.dot(W, x) # Adjustments for batching graph(::typeof(σ), x) = mx.Activation(data = x, act_type = :sigmoid) graph(::typeof(relu), x) = mx.Activation(data = x, act_type = :relu) graph(::typeof(tanh), x) = mx.Activation(data = x, act_type = :tanh) @@ -38,7 +38,7 @@ graph(ctx::Context, d::Affine, x) = register(ctx, mx.FullyConnected(data = x, num_hidden = size(d.W.x, 2), - weight = var(ctx, d.W), + weight = var(ctx, d.W, size(d.W)), bias = var(ctx, d.b, size(d.b, 2)))) # TODO: use actual params} @@ -63,7 +63,7 @@ register(ctx::Context, node) = node function var(ctx::Context, p::Flux.Param, size = nothing) id = gensym() - ctx[:params][id] = size == nothing ? p.x : reshape(p.x, size...) + ctx[:params][id] = size == nothing ? rebatch_last(p.x) : reshape(p.x, size...) return mx.Variable(id) end