redundant
This commit is contained in:
parent
12d05a2db1
commit
c2d6059d73
@ -10,8 +10,6 @@ end
|
||||
|
||||
tond(xs::AArray) = copy!(mx.zeros(size(xs)), xs)
|
||||
|
||||
fromnd(xs::mx.NDArray) = copy(xs)
|
||||
|
||||
ndzero!(xs::mx.NDArray) = copy!(xs, mx.zeros(size(xs)))
|
||||
|
||||
function mxargs(args)
|
||||
@ -50,13 +48,13 @@ end
|
||||
function (model::MXModel)(input)
|
||||
copy!(model.exec.arg_dict[:input], input)
|
||||
mx.forward(model.exec, is_train = true)
|
||||
fromnd(model.exec.outputs[1])
|
||||
copy(model.exec.outputs[1])
|
||||
end
|
||||
|
||||
function Flux.back!(model::MXModel, Δ, x)
|
||||
ndzero!(model.grads[:input])
|
||||
mx.backward(model.exec, tond(Δ))
|
||||
fromnd(model.grads[:input])
|
||||
copy(model.grads[:input])
|
||||
end
|
||||
|
||||
function Flux.update!(model::MXModel, η)
|
||||
|
Loading…
Reference in New Issue
Block a user