From a4812579e9aac55418e901234380c38e57edfb9d Mon Sep 17 00:00:00 2001
From: Mike J Innes
Date: Thu, 23 Feb 2017 21:06:46 +0000
Subject: [PATCH] fix back pass

---
 src/backend/mxnet/model.jl | 14 ++++++++++----
 1 file changed, 10 insertions(+), 4 deletions(-)

diff --git a/src/backend/mxnet/model.jl b/src/backend/mxnet/model.jl
index 835fed7b..b25636d0 100644
--- a/src/backend/mxnet/model.jl
+++ b/src/backend/mxnet/model.jl
@@ -50,8 +50,8 @@ end
 
 function mxnet(model::Flux.Model, input)
   graph = tograph(model, mx.Variable(:input))
-  args = merge(mxparams(graph), Dict(:input => mx.zeros(input)))
-  grads = mxparams(graph)
+  args  = merge(mxparams(graph), Dict(:input => mx.zeros(input)))
+  grads = merge(mxparams(graph), Dict(:input => mx.zeros(input)))
   model = @mxerr graph.stacks Model(model, graph, grads,
                                     mx.bind(graph.node, args = args,
                                             args_grad = grads,
@@ -70,12 +70,18 @@ end
 
 (m::Model)(x) = first(m(batchone(x)))
 
-function Flux.back!(model::Model, Δ, x)
-  ndzero!(model.grads[:input])
+tond(xs::AArray) = copy!(mx.zeros(size(xs)), xs)
+
+function runback!(model::Model, Δ)
+  model.grads[:input][:] = 0
   mx.backward(model.exec, tond(Δ))
   copy(model.grads[:input])
 end
 
+Flux.back!(m::Model, Δ::Batch, x) = rebatch(rebatch_first(runback!(m, rebatch_last(rawbatch(Δ)))))
+
+Flux.back!(m::Model, Δ, x) = first(Flux.back!(m, batchone(Δ), x))
+
 function Flux.update!(model::Model, η)
   for (arg, grad) in zip(model.exec.arg_arrays, model.exec.grad_arrays)
     mx.@nd_as_jl rw = (arg, grad) begin
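
Reviewer note (not part of the patch): registering :input in `grads` makes
mx.bind allocate a gradient buffer for the input, so the back pass can hand
the input gradient back to Flux. runback! zeroes that buffer, runs
mx.backward, and copies the result out, while the two Flux.back! methods
cover the Batch and single-sample cases. A minimal usage sketch against the
2017-era Flux API shown above; `Affine`, the shapes, and the learning rate
are illustrative assumptions, not code from this commit:

    using Flux

    net = Flux.Affine(10, 5)   # hypothetical toy model (assumed early-Flux layer)
    m   = mxnet(net, (1, 10))  # compile the model and bind an MXNet executor

    x  = rand(10)
    y  = m(x)                  # single-sample forward; (m::Model)(x) batches internally

    Δ  = ones(5)               # upstream gradient, same shape as y
    Δx = Flux.back!(m, Δ, x)   # with this patch, returns the gradient w.r.t. x

    Flux.update!(m, 0.1)       # apply accumulated parameter gradients, η = 0.1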