From de72d83f7c616ad424a64965beb5ba9aec7d1ea4 Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Mon, 30 Jan 2017 23:12:01 +0530 Subject: [PATCH] factor out node registration --- src/backend/mxnet/graph.jl | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/src/backend/mxnet/graph.jl b/src/backend/mxnet/graph.jl index 8897940d..f235d7eb 100644 --- a/src/backend/mxnet/graph.jl +++ b/src/backend/mxnet/graph.jl @@ -59,12 +59,17 @@ end interp(ctx, p::Constant) = node(p.value) -function graph(ctx::Context, model, args...) - node = graph(model, args...) - isa(node, mx.SymbolicNode) && (ctx[:stacks][nodename(node)] = stack(ctx)) +function register(ctx::Context, node::mx.SymbolicNode) + ctx[:stacks][nodename(node)] = stack(ctx) return node end +register(ctx::Context, node) = node + +function graph(ctx::Context, model, args...) + register(ctx, graph(model, args...)) +end + function interp(ctx, model, args...) g = Flux.graph(model) g == nothing && return graph(ctx, model, args...)