diff --git a/src/backend/mxnet/model.jl b/src/backend/mxnet/model.jl index 904b90dd..11c6db60 100644 --- a/src/backend/mxnet/model.jl +++ b/src/backend/mxnet/model.jl @@ -46,6 +46,12 @@ mxgroup(x::Tuple) = mx.Group(mxgroup.(x)...) mxungroup(x, outs) = copy(shift!(outs)) mxungroup(x::Tuple, outs) = map(x -> mxungroup(x, outs), x) +function collectt(xs) + ys = [] + mapt(x -> push!(ys, x), xs) + return ys +end + function dictt(ks::Tuple, vs, d = Dict()) for i = 1:length(ks) dictt(ks[i], vs[i], d) @@ -75,7 +81,7 @@ end function Flux.back!(exec::Exec, Δ) mapt(k -> exec.grads[k][:] = 0, exec.graph.input) - mx.backward(exec.exec, MXArray(Δ).data) + mx.backward(exec.exec, map(x -> MXArray(x).data, collectt(Δ))) mapt(k -> copy(exec.grads[k]), exec.graph.input) end