diff --git a/src/backend/mxnet/graph.jl b/src/backend/mxnet/graph.jl index 5016173b..4e4161ae 100644 --- a/src/backend/mxnet/graph.jl +++ b/src/backend/mxnet/graph.jl @@ -79,3 +79,22 @@ function tograph(model, args...) out = interp(ctx, model, map(constant, args)...) return ctx[:params], ctx[:stacks], out end + +# Error Handling + +function errnode(e::mx.MXError) + m = match(r"Error in (\w+):", e.msg) + m == nothing && return + Symbol(m.captures[1]) +end + +macro mxerr(stk, ex) + :(try + $(esc(ex)) + catch e + (isa(e, mx.MXError) && (node = errnode(e)) != nothing) || rethrow() + stk = $(esc(stk)) + @show stk[node] + rethrow() + end) +end diff --git a/src/backend/mxnet/model.jl b/src/backend/mxnet/model.jl index d27a94b6..cebb1832 100644 --- a/src/backend/mxnet/model.jl +++ b/src/backend/mxnet/model.jl @@ -37,10 +37,10 @@ function mxnet(model::Model, input) params, stacks, node = tograph(model, mx.Variable(:input)) args = merge(mxargs(params), Dict(:input => mx.zeros(input))) grads = mxgrads(args) - model = MXModel(model, params, grads, stacks, - mx.bind(node, args = args, - args_grad = grads, - grad_req = mx.GRAD_ADD)) + model = @mxerr stacks MXModel(model, params, grads, stacks, + mx.bind(node, args = args, + args_grad = grads, + grad_req = mx.GRAD_ADD)) loadparams!(model) return model end