fix back pass
This commit is contained in:
parent
4de61fc377
commit
94e384930d
|
@ -76,9 +76,9 @@ function (exec::Exec)(input...)
|
|||
end
|
||||
|
||||
function Flux.back!(exec::Exec, Δ)
|
||||
exec.grads[exec.graph.input[1]][:] = 0
|
||||
mapt(k -> exec.grads[k][:] = 0, exec.graph.input)
|
||||
mx.backward(exec.exec, MXArray(Δ).data)
|
||||
copy(exec.grads[exec.graph.input[1]])
|
||||
mapt(k -> copy(exec.grads[k]), exec.graph.input)
|
||||
end
|
||||
|
||||
function Flux.update!(exec::Exec, η)
|
||||
|
|
|
@ -19,7 +19,7 @@ dm = mxnet(d)
|
|||
@test dm(xs) ≈ d′(xs)
|
||||
|
||||
Δ = back!(dm, randn(10), xs)
|
||||
@test length(Δ) == 20
|
||||
@test length(Δ[1]) == 20
|
||||
update!(dm, 0.1)
|
||||
|
||||
@test dm(xs) ≈ d(xs)
|
||||
|
|
Loading…
Reference in New Issue