diff --git a/src/backend/mxnet/model.jl b/src/backend/mxnet/model.jl index c5c8bb7f..e03d4e0f 100644 --- a/src/backend/mxnet/model.jl +++ b/src/backend/mxnet/model.jl @@ -69,8 +69,8 @@ function executor(graph::Graph, input...) return exec end -function (exec::Exec)(input) - copy!(exec.args[exec.graph.input[1]], input) +function (exec::Exec)(input...) + foreach(kv -> copy!(exec.args[kv[1]], kv[2]), dictt(exec.graph.input, input)) mx.forward(exec.exec, is_train = true) mxungroup(exec.graph.output, copy(exec.outs)) end