MXModel -> MX.Model

Mike J Innes 2017-01-30 23:35:05 +05:30
parent 3cb3aea825
commit 16d6c9aed9

@@ -1,5 +1,5 @@
-type MXModel <: Model
+type Model <: Flux.Model
   model::Any
   params::Dict{Symbol,Any}
   grads::Dict{Symbol,Any}
@@ -25,18 +25,18 @@ function mxgrads(mxargs)
   end
 end

-function loadparams!(model::MXModel)
+function loadparams!(model::Model)
   for (name, arr) in model.exec.arg_dict
     haskey(model.params, name) && copy!(arr, model.params[name])
   end
   return model
 end

-function mxnet(model::Model, input)
+function mxnet(model::Flux.Model, input)
   params, stacks, node = tograph(model, mx.Variable(:input))
   args = merge(mxargs(params), Dict(:input => mx.zeros(input)))
   grads = mxgrads(args)
-  model = @mxerr stacks MXModel(model, params, grads, stacks,
+  model = @mxerr stacks Model(model, params, grads, stacks,
     mx.bind(node, args = args,
             args_grad = grads,
             grad_req = mx.GRAD_ADD))
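
Note: `mxnet` traces the Flux model with `tograph`, allocates zeroed argument and gradient arrays, and binds them into an MXNet executor with `GRAD_ADD` accumulation. A minimal sketch of the intended call, assuming a hypothetical Flux model `m` that takes 10-dimensional column inputs:

    mxmodel = mxnet(m, (10, 1))   # second argument is an input shape, handed to mx.zeros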
@@ -44,19 +44,19 @@ function mxnet(model::Model, input)
   return model
 end

-function (model::MXModel)(input)
+function runmodel(model::Model, input)
   copy!(model.exec.arg_dict[:input], input)
   mx.forward(model.exec, is_train = true)
   copy(model.exec.outputs[1])
 end

-function Flux.back!(model::MXModel, Δ, x)
+function Flux.back!(model::Model, Δ, x)
   ndzero!(model.grads[:input])
   mx.backward(model.exec, tond(Δ))
   copy(model.grads[:input])
 end

-function Flux.update!(model::MXModel, η)
+function Flux.update!(model::Model, η)
   for (arg, grad) in zip(model.exec.arg_arrays, model.exec.grad_arrays)
     mx.@nd_as_jl rw = (arg, grad) begin
       arg .-= grad .* η
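
Note: besides the type rename, this hunk swaps the call overload `(model::MXModel)(input)` for an explicit `runmodel` function, so invoking the wrapped model changes at the call site. A sketch of one training step against the renamed API, with the wrapped model `mxmodel`, an input batch `x`, an output gradient `Δ`, and the learning rate all assumed for illustration:

    y  = runmodel(mxmodel, x)        # forward pass (previously just mxmodel(x))
    Δx = Flux.back!(mxmodel, Δ, x)   # backward pass; returns the gradient w.r.t. the input
    Flux.update!(mxmodel, 0.01)      # move each parameter by -η .* grad, here η = 0.01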
@@ -81,7 +81,7 @@ function rewrite_softmax(model, name)
   return Flux.Capacitor(vertex(SoftmaxOutput(name), g[1]))
 end

-function mx.FeedForward(model::Model; input = :data, label = :softmax, context = mx.cpu())
+function mx.FeedForward(model::Flux.Model; input = :data, label = :softmax, context = mx.cpu())
   model = rewrite_softmax(model, label)
   node, vars = mxgraph(model, input)
   ff = mx.FeedForward(node, context = context)
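
Note: the last hunk only requalifies the argument type, so exporting a Flux model to MXNet's own training API is otherwise unchanged. A hedged sketch, with `m` again a hypothetical Flux model:

    ff = mx.FeedForward(m, input = :data, label = :softmax, context = mx.cpu())
    # ff can then be trained with MXNet's usual mx.fit machinery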