using MacroTools

# NB: this file expects `mx` (MXNet.jl), DataFlow and Flux, plus the graph
# compiler (`tograph`/`mxgraph`), to be in scope in the surrounding module.
# A Flux model compiled down to an MXNet executor, together with its
# Julia-side parameters, gradient buffers and the graph stacks from `tograph`.
type MXModel <: Model
  model::Any                # the original Flux model
  params::Dict{Symbol,Any}  # Julia parameter arrays, keyed by argument name
  grads::Dict{Symbol,Any}   # gradient NDArrays for each argument (incl. :input)
  stack::Dict{Any,Any}      # per-node stacks returned by `tograph`
  exec::mx.Executor         # the bound MXNet executor
end
# Copy a Julia array into a fresh MXNet NDArray.
tond(xs::AArray) = copy!(mx.zeros(size(xs)), xs)
# Zero an NDArray in place.
ndzero!(xs::mx.NDArray) = copy!(xs, mx.zeros(size(xs)))
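# For example (illustrative sketch only; `rand(2, 3)` is an arbitrary array):
#
#   nd = tond(rand(2, 3))  # Julia array -> NDArray
#   ndzero!(nd)            # nd now holds zeros of the same shape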
# Convert a Dict of Julia parameter arrays into a Dict of NDArrays.
function mxargs(args)
  map(args) do kv
    arg, value = kv
    arg => tond(value)
  end
end
# Allocate a zeroed gradient NDArray matching each argument array.
function mxgrads(mxargs)
  map(mxargs) do kv
    arg, value = kv
    arg => mx.zeros(size(value))
  end
end
# Copy the model's Julia-side parameters into the executor's argument arrays.
function loadparams!(model::MXModel)
  for (name, arr) in model.exec.arg_dict
    haskey(model.params, name) && copy!(arr, model.params[name])
  end
  return model
end
# Compile a Flux model to an MXNet executor, binding argument and gradient
# arrays. Gradients are accumulated (`GRAD_ADD`) and cleared in `update!`.
function mxnet(model::Model, input)
  params, stacks, node = tograph(model, mx.Variable(:input))
  args = merge(mxargs(params), Dict(:input => mx.zeros(input)))
  grads = mxgrads(args)
  model = MXModel(model, params, grads, stacks,
                  mx.bind(node, args = args,
                          args_grad = grads,
                          grad_req = mx.GRAD_ADD))
  loadparams!(model)
  return model
end
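# Example usage (a sketch, not from the original source; `Affine` and the
# input size are placeholders — `input` is whatever `mx.zeros` accepts):
#
#   m  = Affine(10, 5)       # hypothetical Flux layer
#   mm = mxnet(m, (1, 10))   # compile and bind against a 1×10 input
#   y  = mm(rand(1, 10))     # forward pass, via the call method below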
# Forward pass: copy the input in, run the executor and return the output.
function (model::MXModel)(input)
  copy!(model.exec.arg_dict[:input], input)
  mx.forward(model.exec, is_train = true)
  copy(model.exec.outputs[1])
end
# Backward pass: zero the input gradient (gradients accumulate under
# `GRAD_ADD`), backpropagate Δ and return the gradient w.r.t. the input.
# `x` is unused here but keeps the signature of Flux's `back!`.
function Flux.back!(model::MXModel, Δ, x)
  ndzero!(model.grads[:input])
  mx.backward(model.exec, tond(Δ))
  copy(model.grads[:input])
end
# SGD update: apply the accumulated gradients with learning rate η, then
# reset them to zero.
function Flux.update!(model::MXModel, η)
  for (arg, grad) in zip(model.exec.arg_arrays, model.exec.grad_arrays)
    mx.@nd_as_jl rw = (arg, grad) begin
      arg .-= grad .* η
      grad[:] = 0
    end
  end
  return model
end
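# A manual training step might then look like (illustrative sketch; `mm`, `x`,
# `y` and the loss gradient Δ are placeholders):
#
#   ŷ = mm(x)               # forward
#   Δ = ŷ .- y              # gradient of a squared-error-style loss
#   Flux.back!(mm, Δ, x)    # backpropagate; parameter grads accumulate
#   Flux.update!(mm, 0.01)  # apply and clear the accumulated gradients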
# MX FeedForward interface
# A stand-in for Flux's `softmax` that compiles to an `mx.SoftmaxOutput` node,
# as required by the `mx.FeedForward` training interface.
type SoftmaxOutput
  name::Symbol
end
graph(s::SoftmaxOutput, xs) = mx.SoftmaxOutput(data = xs, name = s.name)
# Replace a trailing `softmax` in `model` with a `SoftmaxOutput` node; models
# that don't end in `softmax` can't be exported to `mx.FeedForward`.
function rewrite_softmax(model, name)
  model == softmax && return SoftmaxOutput(name)
  g = Flux.graph(model)
  (g == nothing || value(g) ≠ softmax || DataFlow.nin(g) ≠ 1) &&
    error("mx.FeedForward models must end with `softmax`")
  return Flux.Capacitor(vertex(SoftmaxOutput(name), g[1]))
end
# Export a Flux model to MXNet's `FeedForward` training API. The model's final
# `softmax` is rewritten to a `SoftmaxOutput` named after `label`.
function mx.FeedForward(model::Model; input = :data, label = :softmax, context = mx.cpu())
  model = rewrite_softmax(model, label)
  node, vars = mxgraph(model, input)
  ff = mx.FeedForward(node, context = context)
  ff.arg_params = mxargs(vars)
  return ff
end
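# Example (illustrative sketch; the model, optimizer settings and data
# provider are placeholders, not part of this file):
#
#   m  = Chain(Affine(784, 10), softmax)  # hypothetical Flux model
#   ff = mx.FeedForward(m)                # export to MXNet
#   mx.fit(ff, mx.SGD(lr = 0.1), trainprovider, n_epoch = 10)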