update step for mxnet models
This commit is contained in:
parent
c1d85abfc2
commit
6e5e532cc1
@ -19,6 +19,8 @@ tond(xs::AArray) = tond!(mx.zeros(mxdims(size(xs))), xs)
|
|||||||
|
|
||||||
fromnd(xs::mx.NDArray) = copy(xs)'
|
fromnd(xs::mx.NDArray) = copy(xs)'
|
||||||
|
|
||||||
|
ndzero!(xs::mx.NDArray) = copy!(xs, mx.zeros(size(xs)))
|
||||||
|
|
||||||
function mxargs(args)
|
function mxargs(args)
|
||||||
map(args) do kv
|
map(args) do kv
|
||||||
arg, value = kv
|
arg, value = kv
|
||||||
@ -60,8 +62,17 @@ function (model::MXModel)(input)
|
|||||||
end
|
end
|
||||||
|
|
||||||
function Flux.back!(model::MXModel, Δ, x)
|
function Flux.back!(model::MXModel, Δ, x)
|
||||||
input = model.grads[:input]
|
ndzero!(model.grads[:input])
|
||||||
copy!(input, mx.zeros(size(input)))
|
|
||||||
mx.backward(model.exec, tond(Δ))
|
mx.backward(model.exec, tond(Δ))
|
||||||
fromnd(input)
|
fromnd(model.grads[:input])
|
||||||
|
end
|
||||||
|
|
||||||
|
function Flux.update!(model::MXModel, η)
|
||||||
|
for (arg, grad) in zip(model.exec.arg_arrays, model.exec.grad_arrays)
|
||||||
|
mx.@nd_as_jl rw = (arg, grad) begin
|
||||||
|
arg .+= grad .* η
|
||||||
|
grad[:] = 0
|
||||||
|
end
|
||||||
|
end
|
||||||
|
return model
|
||||||
end
|
end
|
||||||
|
@ -25,6 +25,7 @@ end
|
|||||||
|
|
||||||
function update!(p::Param, η)
|
function update!(p::Param, η)
|
||||||
p.x .+= p.Δx .* η
|
p.x .+= p.Δx .* η
|
||||||
|
p.Δx[:] = 0
|
||||||
return p
|
return p
|
||||||
end
|
end
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user