show compile error trace
This commit is contained in:
parent
c2d6059d73
commit
3981485500
@ -79,3 +79,22 @@ function tograph(model, args...)
|
|||||||
out = interp(ctx, model, map(constant, args)...)
|
out = interp(ctx, model, map(constant, args)...)
|
||||||
return ctx[:params], ctx[:stacks], out
|
return ctx[:params], ctx[:stacks], out
|
||||||
end
|
end
|
||||||
|
|
||||||
|
# Error Handling
|
||||||
|
|
||||||
|
function errnode(e::mx.MXError)
|
||||||
|
m = match(r"Error in (\w+):", e.msg)
|
||||||
|
m == nothing && return
|
||||||
|
Symbol(m.captures[1])
|
||||||
|
end
|
||||||
|
|
||||||
|
macro mxerr(stk, ex)
|
||||||
|
:(try
|
||||||
|
$(esc(ex))
|
||||||
|
catch e
|
||||||
|
(isa(e, mx.MXError) && (node = errnode(e)) != nothing) || rethrow()
|
||||||
|
stk = $(esc(stk))
|
||||||
|
@show stk[node]
|
||||||
|
rethrow()
|
||||||
|
end)
|
||||||
|
end
|
||||||
|
@ -37,10 +37,10 @@ function mxnet(model::Model, input)
|
|||||||
params, stacks, node = tograph(model, mx.Variable(:input))
|
params, stacks, node = tograph(model, mx.Variable(:input))
|
||||||
args = merge(mxargs(params), Dict(:input => mx.zeros(input)))
|
args = merge(mxargs(params), Dict(:input => mx.zeros(input)))
|
||||||
grads = mxgrads(args)
|
grads = mxgrads(args)
|
||||||
model = MXModel(model, params, grads, stacks,
|
model = @mxerr stacks MXModel(model, params, grads, stacks,
|
||||||
mx.bind(node, args = args,
|
mx.bind(node, args = args,
|
||||||
args_grad = grads,
|
args_grad = grads,
|
||||||
grad_req = mx.GRAD_ADD))
|
grad_req = mx.GRAD_ADD))
|
||||||
loadparams!(model)
|
loadparams!(model)
|
||||||
return model
|
return model
|
||||||
end
|
end
|
||||||
|
Loading…
Reference in New Issue
Block a user