From 358334a8938d1d11dafe260de043b86a9490b254 Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Wed, 19 Apr 2017 17:13:57 +0100 Subject: [PATCH] mxnet back! for multi outputs --- src/backend/mxnet/model.jl | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/src/backend/mxnet/model.jl b/src/backend/mxnet/model.jl index 904b90dd..11c6db60 100644 --- a/src/backend/mxnet/model.jl +++ b/src/backend/mxnet/model.jl @@ -46,6 +46,12 @@ mxgroup(x::Tuple) = mx.Group(mxgroup.(x)...) mxungroup(x, outs) = copy(shift!(outs)) mxungroup(x::Tuple, outs) = map(x -> mxungroup(x, outs), x) +function collectt(xs) + ys = [] + mapt(x -> push!(ys, x), xs) + return ys +end + function dictt(ks::Tuple, vs, d = Dict()) for i = 1:length(ks) dictt(ks[i], vs[i], d) @@ -75,7 +81,7 @@ end function Flux.back!(exec::Exec, Δ) mapt(k -> exec.grads[k][:] = 0, exec.graph.input) - mx.backward(exec.exec, MXArray(Δ).data) + mx.backward(exec.exec, map(x -> MXArray(x).data, collectt(Δ))) mapt(k -> copy(exec.grads[k]), exec.graph.input) end