diff --git a/src/backend/mxnet/model.jl b/src/backend/mxnet/model.jl
index 36b870f8..4c2f08de 100644
--- a/src/backend/mxnet/model.jl
+++ b/src/backend/mxnet/model.jl
@@ -1,5 +1,5 @@
-type MXModel <: Model
+type Model <: Flux.Model
   model::Any
   params::Dict{Symbol,Any}
   grads::Dict{Symbol,Any}
@@ -25,18 +25,18 @@ function mxgrads(mxargs)
   end
 end
 
-function loadparams!(model::MXModel)
+function loadparams!(model::Model)
   for (name, arr) in model.exec.arg_dict
     haskey(model.params, name) && copy!(arr, model.params[name])
   end
   return model
 end
 
-function mxnet(model::Model, input)
+function mxnet(model::Flux.Model, input)
   params, stacks, node = tograph(model, mx.Variable(:input))
   args = merge(mxargs(params), Dict(:input => mx.zeros(input)))
   grads = mxgrads(args)
-  model = @mxerr stacks MXModel(model, params, grads, stacks,
+  model = @mxerr stacks Model(model, params, grads, stacks,
                                 mx.bind(node, args = args,
                                         args_grad = grads,
                                         grad_req = mx.GRAD_ADD))
@@ -44,19 +44,19 @@ function mxnet(model::Model, input)
   return model
 end
 
-function (model::MXModel)(input)
+function runmodel(model::Model, input)
   copy!(model.exec.arg_dict[:input], input)
   mx.forward(model.exec, is_train = true)
   copy(model.exec.outputs[1])
 end
 
-function Flux.back!(model::MXModel, Δ, x)
+function Flux.back!(model::Model, Δ, x)
   ndzero!(model.grads[:input])
   mx.backward(model.exec, tond(Δ))
   copy(model.grads[:input])
 end
 
-function Flux.update!(model::MXModel, η)
+function Flux.update!(model::Model, η)
   for (arg, grad) in zip(model.exec.arg_arrays, model.exec.grad_arrays)
     mx.@nd_as_jl rw = (arg, grad) begin
      arg .-= grad .* η
@@ -81,7 +81,7 @@ function rewrite_softmax(model, name)
   return Flux.Capacitor(vertex(SoftmaxOutput(name), g[1]))
 end
 
-function mx.FeedForward(model::Model; input = :data, label = :softmax, context = mx.cpu())
+function mx.FeedForward(model::Flux.Model; input = :data, label = :softmax, context = mx.cpu())
   model = rewrite_softmax(model, label)
   node, vars = mxgraph(model, input)
   ff = mx.FeedForward(node, context = context)
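
For reviewers, a minimal sketch of how the renamed API is meant to be driven end to end. The call sequence (mxnet -> runmodel -> Flux.back! -> Flux.update!) comes from this patch; the model `m`, the (10, 1) input shape and the 0.1 learning rate are placeholder assumptions, not part of the change:

# `m` is some pre-existing Flux.Model (placeholder); the shape is whatever mx.zeros accepts.
mxmodel = mxnet(m, (10, 1))               # bind an MXNet executor for `m` with a 10x1 input buffer
y = runmodel(mxmodel, rand(10, 1))        # forward pass; replaces the old (model::MXModel)(input) call
Δ = ones(size(y))                         # dummy output sensitivity, just for illustration
∇x = Flux.back!(mxmodel, Δ, rand(10, 1))  # backward pass; returns a copy of the input gradient
Flux.update!(mxmodel, 0.1)                # in-place SGD-style step: arg .-= grad .* η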