record stack traces

This commit is contained in:
Mike J Innes 2017-01-29 16:09:30 +05:30
parent 8bf5d91605
commit 12d05a2db1
2 changed files with 4 additions and 3 deletions

View File

@ -1,4 +1,4 @@
function symbolname(s::mx.SymbolicNode)
function nodename(s::mx.SymbolicNode)
name = Ref{mx.char_p}(0)
success = Ref(0)
mx.@mxcall(:MXSymbolGetName, (mx.MX_handle, Ref{mx.char_p}, Ref{Int}), s.handle.value, name, success)
@ -62,7 +62,7 @@ interp(ctx, p::Constant) = node(p.value)
function graph(ctx::Context, model, args...)
node = graph(model, interpv(ctx, args)...)
# isa(node, Tensor) && (ctx[:stacks][node.op.name] = stack(ctx))
isa(node, mx.SymbolicNode) && (ctx[:stacks][nodename(node)] = stack(ctx))
return node
end

View File

@ -4,6 +4,7 @@ type MXModel <: Model
model::Any
params::Dict{Symbol,Any}
grads::Dict{Symbol,Any}
stack::Dict{Any,Any}
exec::mx.Executor
end
@ -38,7 +39,7 @@ function mxnet(model::Model, input)
params, stacks, node = tograph(model, mx.Variable(:input))
args = merge(mxargs(params), Dict(:input => mx.zeros(input)))
grads = mxgrads(args)
model = MXModel(model, params, grads,
model = MXModel(model, params, grads, stacks,
mx.bind(node, args = args,
args_grad = grads,
grad_req = mx.GRAD_ADD))