diff --git a/src/backend/mxnet/graph.jl b/src/backend/mxnet/graph.jl index 0e6b9aa4..5016173b 100644 --- a/src/backend/mxnet/graph.jl +++ b/src/backend/mxnet/graph.jl @@ -1,4 +1,4 @@ -function symbolname(s::mx.SymbolicNode) +function nodename(s::mx.SymbolicNode) name = Ref{mx.char_p}(0) success = Ref(0) mx.@mxcall(:MXSymbolGetName, (mx.MX_handle, Ref{mx.char_p}, Ref{Int}), s.handle.value, name, success) @@ -62,7 +62,7 @@ interp(ctx, p::Constant) = node(p.value) function graph(ctx::Context, model, args...) node = graph(model, interpv(ctx, args)...) - # isa(node, Tensor) && (ctx[:stacks][node.op.name] = stack(ctx)) + isa(node, mx.SymbolicNode) && (ctx[:stacks][nodename(node)] = stack(ctx)) return node end diff --git a/src/backend/mxnet/model.jl b/src/backend/mxnet/model.jl index 88472aa3..24ce7cfe 100644 --- a/src/backend/mxnet/model.jl +++ b/src/backend/mxnet/model.jl @@ -4,6 +4,7 @@ type MXModel <: Model model::Any params::Dict{Symbol,Any} grads::Dict{Symbol,Any} + stack::Dict{Any,Any} exec::mx.Executor end @@ -38,7 +39,7 @@ function mxnet(model::Model, input) params, stacks, node = tograph(model, mx.Variable(:input)) args = merge(mxargs(params), Dict(:input => mx.zeros(input))) grads = mxgrads(args) - model = MXModel(model, params, grads, + model = MXModel(model, params, grads, stacks, mx.bind(node, args = args, args_grad = grads, grad_req = mx.GRAD_ADD))