diff --git a/src/backend/mxnet/graph.jl b/src/backend/mxnet/graph.jl index 8897940d..f235d7eb 100644 --- a/src/backend/mxnet/graph.jl +++ b/src/backend/mxnet/graph.jl @@ -59,12 +59,17 @@ end interp(ctx, p::Constant) = node(p.value) -function graph(ctx::Context, model, args...) - node = graph(model, args...) - isa(node, mx.SymbolicNode) && (ctx[:stacks][nodename(node)] = stack(ctx)) +function register(ctx::Context, node::mx.SymbolicNode) + ctx[:stacks][nodename(node)] = stack(ctx) return node end +register(ctx::Context, node) = node + +function graph(ctx::Context, model, args...) + register(ctx, graph(model, args...)) +end + function interp(ctx, model, args...) g = Flux.graph(model) g == nothing && return graph(ctx, model, args...)