diff --git a/src/backend/mxnet/graph.jl b/src/backend/mxnet/graph.jl index 1e430f14..c60f1db6 100644 --- a/src/backend/mxnet/graph.jl +++ b/src/backend/mxnet/graph.jl @@ -85,7 +85,7 @@ function tograph(model, args...; feedforward = false) params = Dict(), stacks = Dict(), feedforward = feedforward) out = @ithrow graph(ctx, model, args...) - return ctx[:params], ctx[:stacks], out + return Graph(out, ctx[:params], ctx[:stacks]) end # Error Handling diff --git a/src/backend/mxnet/model.jl b/src/backend/mxnet/model.jl index 4ae1f93f..10f3df17 100644 --- a/src/backend/mxnet/model.jl +++ b/src/backend/mxnet/model.jl @@ -1,10 +1,15 @@ using Flux: batchone, rebatch +type Graph + node::mx.SymbolicNode + params::Dict{Symbol,Any} + stacks::Dict{Any,Any} +end + type Model <: Flux.Model model::Any - params::Dict{Symbol,Any} + graph::Graph grads::Dict{Symbol,Any} - stack::Dict{Any,Any} exec::mx.Executor end @@ -30,19 +35,19 @@ end function loadparams!(model::Model) for (name, arr) in model.exec.arg_dict - haskey(model.params, name) && copy!(arr, model.params[name]) + haskey(model.graph.params, name) && copy!(arr, model.graph.params[name]) end return model end function mxnet(model::Flux.Model, input) - params, stacks, node = tograph(model, mx.Variable(:input)) - args = merge(mxargs(params), Dict(:input => mx.zeros(input))) + graph = tograph(model, mx.Variable(:input)) + args = merge(mxargs(graph.params), Dict(:input => mx.zeros(input))) grads = mxgrads(args) - model = @mxerr stacks Model(model, params, grads, stacks, - mx.bind(node, args = args, - args_grad = grads, - grad_req = mx.GRAD_ADD)) + model = @mxerr graph.stacks Model(model, graph, grads, + mx.bind(graph.node, args = args, + args_grad = grads, + grad_req = mx.GRAD_ADD)) loadparams!(model) return model end @@ -94,8 +99,8 @@ end function mx.FeedForward(model::Flux.Model; input = :data, label = :softmax, context = mx.cpu()) model = rewrite_softmax(model, label) - vars, stacks, node = tograph(model, mx.Variable(input), feedforward=true) - ff = mx.FeedForward(node, context = context) - isempty(vars) || (ff.arg_params = mxargs(vars)) + graph = tograph(model, mx.Variable(input), feedforward=true) + ff = mx.FeedForward(graph.node, context = context) + isempty(graph.params) || (ff.arg_params = mxargs(graph.params)) return ff end