mxnet back! for multi outputs
This commit is contained in:
parent
42a8117704
commit
358334a893
@ -46,6 +46,12 @@ mxgroup(x::Tuple) = mx.Group(mxgroup.(x)...)
|
||||
mxungroup(x, outs) = copy(shift!(outs))
|
||||
mxungroup(x::Tuple, outs) = map(x -> mxungroup(x, outs), x)
|
||||
|
||||
function collectt(xs)
|
||||
ys = []
|
||||
mapt(x -> push!(ys, x), xs)
|
||||
return ys
|
||||
end
|
||||
|
||||
function dictt(ks::Tuple, vs, d = Dict())
|
||||
for i = 1:length(ks)
|
||||
dictt(ks[i], vs[i], d)
|
||||
@ -75,7 +81,7 @@ end
|
||||
|
||||
function Flux.back!(exec::Exec, Δ)
|
||||
mapt(k -> exec.grads[k][:] = 0, exec.graph.input)
|
||||
mx.backward(exec.exec, MXArray(Δ).data)
|
||||
mx.backward(exec.exec, map(x -> MXArray(x).data, collectt(Δ)))
|
||||
mapt(k -> copy(exec.grads[k]), exec.graph.input)
|
||||
end
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user