mxnet back! for multi outputs

This commit is contained in:
Mike J Innes 2017-04-19 17:13:57 +01:00
parent 42a8117704
commit 358334a893

View File

@ -46,6 +46,12 @@ mxgroup(x::Tuple) = mx.Group(mxgroup.(x)...)
mxungroup(x, outs) = copy(shift!(outs))
mxungroup(x::Tuple, outs) = map(x -> mxungroup(x, outs), x)
function collectt(xs)
ys = []
mapt(x -> push!(ys, x), xs)
return ys
end
function dictt(ks::Tuple, vs, d = Dict())
for i = 1:length(ks)
dictt(ks[i], vs[i], d)
@ -75,7 +81,7 @@ end
function Flux.back!(exec::Exec, Δ)
mapt(k -> exec.grads[k][:] = 0, exec.graph.input)
mx.backward(exec.exec, MXArray(Δ).data)
mx.backward(exec.exec, map(x -> MXArray(x).data, collectt(Δ)))
mapt(k -> copy(exec.grads[k]), exec.graph.input)
end