fix back pass
This commit is contained in:
parent
2d77220d60
commit
a4812579e9
@ -51,7 +51,7 @@ end
|
|||||||
function mxnet(model::Flux.Model, input)
|
function mxnet(model::Flux.Model, input)
|
||||||
graph = tograph(model, mx.Variable(:input))
|
graph = tograph(model, mx.Variable(:input))
|
||||||
args = merge(mxparams(graph), Dict(:input => mx.zeros(input)))
|
args = merge(mxparams(graph), Dict(:input => mx.zeros(input)))
|
||||||
grads = mxparams(graph)
|
grads = merge(mxparams(graph), Dict(:input => mx.zeros(input)))
|
||||||
model = @mxerr graph.stacks Model(model, graph, grads,
|
model = @mxerr graph.stacks Model(model, graph, grads,
|
||||||
mx.bind(graph.node, args = args,
|
mx.bind(graph.node, args = args,
|
||||||
args_grad = grads,
|
args_grad = grads,
|
||||||
@ -70,12 +70,18 @@ end
|
|||||||
|
|
||||||
(m::Model)(x) = first(m(batchone(x)))
|
(m::Model)(x) = first(m(batchone(x)))
|
||||||
|
|
||||||
function Flux.back!(model::Model, Δ, x)
|
tond(xs::AArray) = copy!(mx.zeros(size(xs)), xs)
|
||||||
ndzero!(model.grads[:input])
|
|
||||||
|
function runback!(model::Model, Δ)
|
||||||
|
model.grads[:input][:] = 0
|
||||||
mx.backward(model.exec, tond(Δ))
|
mx.backward(model.exec, tond(Δ))
|
||||||
copy(model.grads[:input])
|
copy(model.grads[:input])
|
||||||
end
|
end
|
||||||
|
|
||||||
|
Flux.back!(m::Model, Δ::Batch, x) = rebatch(rebatch_first(runback!(m, rebatch_last(rawbatch(Δ)))))
|
||||||
|
|
||||||
|
Flux.back!(m::Model, Δ, x) = first(Flux.back!(m, batchone(Δ), x))
|
||||||
|
|
||||||
function Flux.update!(model::Model, η)
|
function Flux.update!(model::Model, η)
|
||||||
for (arg, grad) in zip(model.exec.arg_arrays, model.exec.grad_arrays)
|
for (arg, grad) in zip(model.exec.arg_arrays, model.exec.grad_arrays)
|
||||||
mx.@nd_as_jl rw = (arg, grad) begin
|
mx.@nd_as_jl rw = (arg, grad) begin
|
||||||
|
Loading…
Reference in New Issue
Block a user