record stack traces
This commit is contained in:
parent
8bf5d91605
commit
12d05a2db1
@ -1,4 +1,4 @@
|
||||
function symbolname(s::mx.SymbolicNode)
|
||||
function nodename(s::mx.SymbolicNode)
|
||||
name = Ref{mx.char_p}(0)
|
||||
success = Ref(0)
|
||||
mx.@mxcall(:MXSymbolGetName, (mx.MX_handle, Ref{mx.char_p}, Ref{Int}), s.handle.value, name, success)
|
||||
@ -62,7 +62,7 @@ interp(ctx, p::Constant) = node(p.value)
|
||||
|
||||
function graph(ctx::Context, model, args...)
|
||||
node = graph(model, interpv(ctx, args)...)
|
||||
# isa(node, Tensor) && (ctx[:stacks][node.op.name] = stack(ctx))
|
||||
isa(node, mx.SymbolicNode) && (ctx[:stacks][nodename(node)] = stack(ctx))
|
||||
return node
|
||||
end
|
||||
|
||||
|
@ -4,6 +4,7 @@ type MXModel <: Model
|
||||
model::Any
|
||||
params::Dict{Symbol,Any}
|
||||
grads::Dict{Symbol,Any}
|
||||
stack::Dict{Any,Any}
|
||||
exec::mx.Executor
|
||||
end
|
||||
|
||||
@ -38,7 +39,7 @@ function mxnet(model::Model, input)
|
||||
params, stacks, node = tograph(model, mx.Variable(:input))
|
||||
args = merge(mxargs(params), Dict(:input => mx.zeros(input)))
|
||||
grads = mxgrads(args)
|
||||
model = MXModel(model, params, grads,
|
||||
model = MXModel(model, params, grads, stacks,
|
||||
mx.bind(node, args = args,
|
||||
args_grad = grads,
|
||||
grad_req = mx.GRAD_ADD))
|
||||
|
Loading…
Reference in New Issue
Block a user