update step for mxnet models

This commit is contained in:
Mike J Innes 2016-08-23 23:56:31 +01:00
parent c1d85abfc2
commit 6e5e532cc1
2 changed files with 15 additions and 3 deletions

View File

@ -19,6 +19,8 @@ tond(xs::AArray) = tond!(mx.zeros(mxdims(size(xs))), xs)
fromnd(xs::mx.NDArray) = copy(xs)'
ndzero!(xs::mx.NDArray) = copy!(xs, mx.zeros(size(xs)))
function mxargs(args)
map(args) do kv
arg, value = kv
@ -60,8 +62,17 @@ function (model::MXModel)(input)
end
function Flux.back!(model::MXModel, Δ, x)
input = model.grads[:input]
copy!(input, mx.zeros(size(input)))
ndzero!(model.grads[:input])
mx.backward(model.exec, tond(Δ))
fromnd(input)
fromnd(model.grads[:input])
end
function Flux.update!(model::MXModel, η)
for (arg, grad) in zip(model.exec.arg_arrays, model.exec.grad_arrays)
mx.@nd_as_jl rw = (arg, grad) begin
arg .+= grad .* η
grad[:] = 0
end
end
return model
end

View File

@ -25,6 +25,7 @@ end
function update!(p::Param, η)
p.x .+= p.Δx .* η
p.Δx[:] = 0
return p
end