fix back pass

This commit is contained in:
Mike J Innes 2017-03-30 19:36:59 +01:00
parent 4de61fc377
commit 94e384930d
2 changed files with 3 additions and 3 deletions

View File

@ -76,9 +76,9 @@ function (exec::Exec)(input...)
end
function Flux.back!(exec::Exec, Δ)
exec.grads[exec.graph.input[1]][:] = 0
mapt(k -> exec.grads[k][:] = 0, exec.graph.input)
mx.backward(exec.exec, MXArray(Δ).data)
copy(exec.grads[exec.graph.input[1]])
mapt(k -> copy(exec.grads[k]), exec.graph.input)
end
function Flux.update!(exec::Exec, η)

View File

@ -19,7 +19,7 @@ dm = mxnet(d)
@test dm(xs) d(xs)
Δ = back!(dm, randn(10), xs)
@test length(Δ) == 20
@test length(Δ[1]) == 20
update!(dm, 0.1)
@test dm(xs) d(xs)