factor out node registration

This commit is contained in:
Mike J Innes 2017-01-30 23:12:01 +05:30
parent ac1b686db5
commit de72d83f7c

View File

@ -59,12 +59,17 @@ end
interp(ctx, p::Constant) = node(p.value)
function graph(ctx::Context, model, args...)
node = graph(model, args...)
isa(node, mx.SymbolicNode) && (ctx[:stacks][nodename(node)] = stack(ctx))
function register(ctx::Context, node::mx.SymbolicNode)
ctx[:stacks][nodename(node)] = stack(ctx)
return node
end
register(ctx::Context, node) = node
function graph(ctx::Context, model, args...)
register(ctx, graph(model, args...))
end
function interp(ctx, model, args...)
g = Flux.graph(model)
g == nothing && return graph(ctx, model, args...)