fix back pass
This commit is contained in:
parent
2d77220d60
commit
a4812579e9
@ -50,8 +50,8 @@ end
|
||||
|
||||
function mxnet(model::Flux.Model, input)
|
||||
graph = tograph(model, mx.Variable(:input))
|
||||
args = merge(mxparams(graph), Dict(:input => mx.zeros(input)))
|
||||
grads = mxparams(graph)
|
||||
args = merge(mxparams(graph), Dict(:input => mx.zeros(input)))
|
||||
grads = merge(mxparams(graph), Dict(:input => mx.zeros(input)))
|
||||
model = @mxerr graph.stacks Model(model, graph, grads,
|
||||
mx.bind(graph.node, args = args,
|
||||
args_grad = grads,
|
||||
@ -70,12 +70,18 @@ end
|
||||
|
||||
(m::Model)(x) = first(m(batchone(x)))
|
||||
|
||||
function Flux.back!(model::Model, Δ, x)
|
||||
ndzero!(model.grads[:input])
|
||||
tond(xs::AArray) = copy!(mx.zeros(size(xs)), xs)
|
||||
|
||||
function runback!(model::Model, Δ)
|
||||
model.grads[:input][:] = 0
|
||||
mx.backward(model.exec, tond(Δ))
|
||||
copy(model.grads[:input])
|
||||
end
|
||||
|
||||
Flux.back!(m::Model, Δ::Batch, x) = rebatch(rebatch_first(runback!(m, rebatch_last(rawbatch(Δ)))))
|
||||
|
||||
Flux.back!(m::Model, Δ, x) = first(Flux.back!(m, batchone(Δ), x))
|
||||
|
||||
function Flux.update!(model::Model, η)
|
||||
for (arg, grad) in zip(model.exec.arg_arrays, model.exec.grad_arrays)
|
||||
mx.@nd_as_jl rw = (arg, grad) begin
|
||||
|
Loading…
Reference in New Issue
Block a user