Fix the backward pass

This commit is contained in:
Mike J Innes 2017-02-23 21:06:46 +00:00
parent 2d77220d60
commit a4812579e9

View File

@@ -50,8 +50,8 @@ end
function mxnet(model::Flux.Model, input)
graph = tograph(model, mx.Variable(:input))
args = merge(mxparams(graph), Dict(:input => mx.zeros(input)))
grads = mxparams(graph)
args = merge(mxparams(graph), Dict(:input => mx.zeros(input)))
grads = merge(mxparams(graph), Dict(:input => mx.zeros(input)))
model = @mxerr graph.stacks Model(model, graph, grads,
mx.bind(graph.node, args = args,
args_grad = grads,
@@ -70,12 +70,18 @@ end
# Apply the model to a single (unbatched) input: wrap it in a batch of one,
# run the batched forward pass, and unwrap the single result.
function (m::Model)(x)
    return first(m(batchone(x)))
end
function Flux.back!(model::Model, Δ, x)
ndzero!(model.grads[:input])
# Convert a Julia array into an MXNet NDArray of the same size by copying
# its contents into a freshly allocated zero buffer.
function tond(xs::AArray)
    return copy!(mx.zeros(size(xs)), xs)
end
# Run the MXNet backward pass with output sensitivity Δ and return a copy of
# the resulting gradient with respect to the :input argument.
function runback!(model::Model, Δ)
    input_grad = model.grads[:input]
    input_grad[:] = 0  # zero the input-gradient buffer before running backward
    mx.backward(model.exec, tond(Δ))
    return copy(input_grad)
end
# Batched backward pass: unwrap the batch, run the MXNet backward pass,
# and re-wrap the input gradient as a batch.
function Flux.back!(m::Model, Δ::Batch, x)
    raw = rawbatch(Δ)
    grad = runback!(m, rebatch_last(raw))
    return rebatch(rebatch_first(grad))
end
# Unbatched backward pass: promote Δ to a batch of one, delegate to the
# batched method, and unwrap the single gradient.
function Flux.back!(m::Model, Δ, x)
    return first(Flux.back!(m, batchone(Δ), x))
end
function Flux.update!(model::Model, η)
for (arg, grad) in zip(model.exec.arg_arrays, model.exec.grad_arrays)
mx.@nd_as_jl rw = (arg, grad) begin