show compile error trace

This commit is contained in:
Mike J Innes 2017-01-29 17:29:37 +05:30
parent c2d6059d73
commit 3981485500
2 changed files with 23 additions and 4 deletions

View File

@ -79,3 +79,22 @@ function tograph(model, args...)
out = interp(ctx, model, map(constant, args)...)
return ctx[:params], ctx[:stacks], out
end
# Error Handling
function errnode(e::mx.MXError)
m = match(r"Error in (\w+):", e.msg)
m == nothing && return
Symbol(m.captures[1])
end
macro mxerr(stk, ex)
:(try
$(esc(ex))
catch e
(isa(e, mx.MXError) && (node = errnode(e)) != nothing) || rethrow()
stk = $(esc(stk))
@show stk[node]
rethrow()
end)
end

View File

@ -37,10 +37,10 @@ function mxnet(model::Model, input)
params, stacks, node = tograph(model, mx.Variable(:input))
args = merge(mxargs(params), Dict(:input => mx.zeros(input)))
grads = mxgrads(args)
model = MXModel(model, params, grads, stacks,
mx.bind(node, args = args,
args_grad = grads,
grad_req = mx.GRAD_ADD))
model = @mxerr stacks MXModel(model, params, grads, stacks,
mx.bind(node, args = args,
args_grad = grads,
grad_req = mx.GRAD_ADD))
loadparams!(model)
return model
end