remove tond
This commit is contained in:
parent
bb70f401be
commit
70168319eb
@ -7,11 +7,6 @@ type MXModel <: Model
|
|||||||
exec::mx.Executor
|
exec::mx.Executor
|
||||||
end
|
end
|
||||||
|
|
||||||
function tond!(nd::mx.NDArray, xs::AArray)
|
|
||||||
mx.copy_ignore_shape!(nd, xs)
|
|
||||||
nd
|
|
||||||
end
|
|
||||||
|
|
||||||
tond(xs::AArray) = copy!(mx.zeros(size(xs)), xs)
|
tond(xs::AArray) = copy!(mx.zeros(size(xs)), xs)
|
||||||
|
|
||||||
fromnd(xs::mx.NDArray) = copy(xs)
|
fromnd(xs::mx.NDArray) = copy(xs)
|
||||||
@ -34,7 +29,7 @@ end
|
|||||||
|
|
||||||
function loadparams!(model::MXModel)
|
function loadparams!(model::MXModel)
|
||||||
for (name, arr) in model.exec.arg_dict
|
for (name, arr) in model.exec.arg_dict
|
||||||
haskey(model.params, name) && tond!(arr, model.params[name])
|
haskey(model.params, name) && copy!(arr, model.params[name])
|
||||||
end
|
end
|
||||||
return model
|
return model
|
||||||
end
|
end
|
||||||
@ -52,7 +47,7 @@ function mxnet(model::Model, input)
|
|||||||
end
|
end
|
||||||
|
|
||||||
function (model::MXModel)(input)
|
function (model::MXModel)(input)
|
||||||
tond!(model.exec.arg_dict[:input], input)
|
copy!(model.exec.arg_dict[:input], input)
|
||||||
mx.forward(model.exec, is_train = true)
|
mx.forward(model.exec, is_train = true)
|
||||||
fromnd(model.exec.outputs[1])
|
fromnd(model.exec.outputs[1])
|
||||||
end
|
end
|
||||||
|
Loading…
Reference in New Issue
Block a user