diff --git a/src/backend/mxnet/model.jl b/src/backend/mxnet/model.jl
index 0f7df34f..c5c8bb7f 100644
--- a/src/backend/mxnet/model.jl
+++ b/src/backend/mxnet/model.jl
@@ -30,7 +30,7 @@ function mxparams(g::Graph)
   return params
 end
 
-ndparams(d::Dict{Symbol,MXArray}) = Dict(k => v.data for (k, v) in d)
+ndparams(d) = Dict{Symbol,mx.NDArray}(k => v.data for (k, v) in d)
 
 struct Exec
   graph::Graph
@@ -48,9 +48,18 @@ mxgroup(x::Tuple) = mx.Group(mxgroup.(x)...)
 mxungroup(x, outs) = copy(shift!(outs))
 mxungroup(x::Tuple, outs) = map(x -> mxungroup(x, outs), x)
 
-function executor(graph::Graph, input)
-  args = merge(mxparams(graph), Dict(graph.input[1] => MXArray(input)))
-  grads = merge(mxparams(graph), Dict(graph.input[1] => MXArray(input)))
+function dictt(ks::Tuple, vs, d = Dict())
+  for i = 1:length(ks)
+    dictt(ks[i], vs[i], d)
+  end
+  return d
+end
+
+dictt(k, v, d = Dict()) = (d[k] = v; d)
+
+function executor(graph::Graph, input...)
+  args = merge(mxparams(graph), dictt(graph.input, mapt(d->MXArray(size(d)), input)))
+  grads = merge(mxparams(graph), dictt(graph.input, mapt(d->MXArray(size(d)), input)))
   exec = mx.bind(mxgroup(graph.output),
                  args = ndparams(args),
                  args_grad = ndparams(grads),
@@ -97,13 +106,13 @@ mxnet(model) = Model(model)
 
 import Base: @get!
 
-executor(m::Model, input) = @get!(m.execs, input, executor(m.graph, input))
+executor(m::Model, input) = @get!(m.execs, size(input), executor(m.graph, input))
 
 function (m::Model)(x)
   !isdefined(m, :graph) &&
     (m.graph = tograph(m.model, mapt(_ -> gensym("input"), x)))
   @mxerr m.graph.stacks runrawbatched(x) do x
-    m.last = exec = executor(m, size(x))
+    m.last = exec = executor(m, x)
     exec(x)
   end
 end
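
The `ndparams` change loosens the argument type and instead pins down the *result* type, so `mx.bind` always receives a `Dict{Symbol,mx.NDArray}` no matter what dictionary type the caller built. A small standalone sketch of that typed-comprehension pattern (illustrative types, not the backend's; `Wrapped` stands in for `MXArray`):

```julia
# Typed-Dict comprehension: the value type is fixed up front instead of
# being inferred from whatever the input Dict happened to contain.
struct Wrapped
  data::Vector{Float64}   # stand-in for MXArray's wrapped NDArray
end

params = Dict(:w => Wrapped(rand(4)), :b => Wrapped(rand(2)))
nd = Dict{Symbol,Vector{Float64}}(k => v.data for (k, v) in params)
```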
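The new `dictt` helper zips a tuple of graph input names against a matching tuple of values into one flat `Dict`, recursing through nested tuples and bottoming out at the single-key method. This is what lets `executor` accept multiple inputs rather than only `graph.input[1]`. A runnable sketch of its behaviour on plain data:

```julia
dictt(k, v, d = Dict()) = (d[k] = v; d)       # base case: record one key/value pair

function dictt(ks::Tuple, vs, d = Dict())     # tuple case: walk keys and values together
  for i = 1:length(ks)
    dictt(ks[i], vs[i], d)                    # nested tuples recurse; scalars hit the base case
  end
  return d
end

dictt((:a, (:b, :c)), (1, (2, 3)))  # => Dict(:a => 1, :b => 2, :c => 3)
```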
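With `executor(m::Model, input)` now keyed on `size(input)` rather than the input itself, the model keeps one bound executor per input shape: the first batch of a given shape pays the bind cost, and later batches of that shape reuse the cached executor. A minimal sketch of the same caching pattern, using `get!` and `gensym()` as hypothetical stand-ins for the backend's `@get!` and `mx.bind`:

```julia
# Shape-keyed executor cache: rebinding only happens for unseen input sizes.
execs = Dict{Tuple,Symbol}()
bind_executor(sz) = (println("binding executor for input size $sz"); gensym())

cached_executor(x) = get!(() -> bind_executor(size(x)), execs, size(x))

cached_executor(rand(10, 5))  # binds a fresh executor
cached_executor(rand(10, 5))  # cache hit: no rebind
cached_executor(rand(10, 7))  # new shape, new executor
```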