implement mxnet backward pass
This commit is contained in:
parent
2635283bf1
commit
8c7e74bf9f
@ -1,6 +1,7 @@
|
|||||||
# Wrapper that pairs a model with the MXNet executor built for it.
type MXModel
  # The wrapped model object (left untyped).
  model::Any
  # Values keyed by MXNet argument name; `load!` copies these into the
  # executor's matching arguments.
  params::Dict{Symbol,Any}
  # Gradient arrays keyed by argument name; `Flux.back!` reads `:input`.
  grads::Dict{Symbol,Any}
  # The MXNet executor used for the forward and backward passes.
  exec::mx.Executor
end
|
||||||
|
|
||||||
@ -9,6 +10,15 @@ mxdims(dims::NTuple) =
|
|||||||
|
|
||||||
mxdims(n::Integer) = mxdims((n,))
|
mxdims(n::Integer) = mxdims((n,))
|
||||||
|
|
||||||
|
# Copy `xs'` into the existing MXNet array `nd` and return `nd`.
#
# The copy goes through `mx.copy_ignore_shape!`, so the shapes of `nd` and
# `xs'` are not required to match.
# NOTE(review): the `'` presumably converts between the Julia and MXNet
# dimension orderings (see `mxdims`) -- confirm.
function tond!(nd::mx.NDArray, xs::AArray)
  flipped = xs'
  mx.copy_ignore_shape!(nd, flipped)
  return nd
end
|
||||||
|
|
||||||
|
# Allocate a zero-filled MXNet array of shape `mxdims(size(xs))` and fill it
# from `xs` via `tond!`. Returns the new array.
function tond(xs::AArray)
  target = mx.zeros(mxdims(size(xs)))
  return tond!(target, xs)
end
|
||||||
|
|
||||||
|
# Inverse of `tond`: copy the MXNet array `xs` and return the copy with `'`
# applied.
# NOTE(review): `copy` on an `mx.NDArray` presumably yields a plain Julia
# array -- confirm against MXNet.jl.
function fromnd(xs::mx.NDArray)
  copied = copy(xs)
  return copied'
end
|
||||||
|
|
||||||
function mxargs(args)
|
function mxargs(args)
|
||||||
map(args) do kv
|
map(args) do kv
|
||||||
arg, value = kv
|
arg, value = kv
|
||||||
@ -16,10 +26,16 @@ function mxargs(args)
|
|||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
|
# Build gradient storage for a collection of `name => array` pairs: every
# entry is mapped to `name => mx.zeros(size(array))`, i.e. a zero-filled MXNet
# array with the same size as the corresponding argument.
function mxgrads(mxargs)
  function zerograd(entry)
    name, array = entry
    return name => mx.zeros(size(array))
  end
  return map(zerograd, mxargs)
end
|
||||||
|
|
||||||
# Copy the values stored in `model.params` into the executor's arguments.
#
# Only executor arguments whose name is a key of `model.params` are written;
# all others are left untouched. Returns `model`.
function load!(model::MXModel)
  for (name, arr) in model.exec.arg_dict
    if haskey(model.params, name)
      tond!(arr, model.params[name])
    end
  end
  return model
end
|
||||||
@ -28,14 +44,24 @@ function mxnet(model::Model, input)
|
|||||||
vars = Dict{Symbol,Any}()
|
vars = Dict{Symbol,Any}()
|
||||||
node = graph(vars, model, mx.Variable(:input))
|
node = graph(vars, model, mx.Variable(:input))
|
||||||
args = merge(mxargs(vars), Dict(:input => mx.zeros(mxdims(input))))
|
args = merge(mxargs(vars), Dict(:input => mx.zeros(mxdims(input))))
|
||||||
model = MXModel(model, vars, mx.bind(node, args = args, grad_req = mx.GRAD_NOP))
|
grads = mxgrads(args)
|
||||||
|
model = MXModel(model, vars, grads,
|
||||||
|
mx.bind(node, args = args,
|
||||||
|
args_grad = grads,
|
||||||
|
grad_req = mx.GRAD_ADD))
|
||||||
load!(model)
|
load!(model)
|
||||||
return model
|
return model
|
||||||
end
|
end
|
||||||
|
|
||||||
# Forward pass: write `input` into the executor's `:input` argument, run the
# executor forward, and return its first output converted with `fromnd`.
function (model::MXModel)(input)
  executor = model.exec
  tond!(executor.arg_dict[:input], input)
  mx.forward(executor)
  return fromnd(executor.outputs[1])
end
|
||||||
|
|
||||||
|
# Backward pass: propagate the output gradient `Δ` through the executor and
# return the gradient with respect to the input (converted with `fromnd`).
#
# `x` is accepted as part of the `back!` signature but is not used here.
function Flux.back!(model::MXModel, Δ, x)
  inputgrad = model.grads[:input]
  # Reset the stored input gradient before running backward; the executor is
  # bound with `grad_req = mx.GRAD_ADD`, so stale values would otherwise be
  # added to the new result.
  copy!(inputgrad, mx.zeros(size(inputgrad)))
  mx.backward(model.exec, tond(Δ))
  return fromnd(inputgrad)
end
|
||||||
|
Loading…
Reference in New Issue
Block a user