diff --git a/src/backend/mxnet/model.jl b/src/backend/mxnet/model.jl index da3b62f6..88472aa3 100644 --- a/src/backend/mxnet/model.jl +++ b/src/backend/mxnet/model.jl @@ -7,11 +7,6 @@ type MXModel <: Model exec::mx.Executor end -function tond!(nd::mx.NDArray, xs::AArray) - mx.copy_ignore_shape!(nd, xs) - nd -end - tond(xs::AArray) = copy!(mx.zeros(size(xs)), xs) fromnd(xs::mx.NDArray) = copy(xs) @@ -34,7 +29,7 @@ end function loadparams!(model::MXModel) for (name, arr) in model.exec.arg_dict - haskey(model.params, name) && tond!(arr, model.params[name]) + haskey(model.params, name) && copy!(arr, model.params[name]) end return model end @@ -52,7 +47,7 @@ function mxnet(model::Model, input) end function (model::MXModel)(input) - tond!(model.exec.arg_dict[:input], input) + copy!(model.exec.arg_dict[:input], input) mx.forward(model.exec, is_train = true) fromnd(model.exec.outputs[1]) end