factor out node registration
This commit is contained in:
parent
ac1b686db5
commit
de72d83f7c
@ -59,12 +59,17 @@ end
|
||||
|
||||
interp(ctx, p::Constant) = node(p.value)
|
||||
|
||||
function graph(ctx::Context, model, args...)
|
||||
node = graph(model, args...)
|
||||
isa(node, mx.SymbolicNode) && (ctx[:stacks][nodename(node)] = stack(ctx))
|
||||
function register(ctx::Context, node::mx.SymbolicNode)
|
||||
ctx[:stacks][nodename(node)] = stack(ctx)
|
||||
return node
|
||||
end
|
||||
|
||||
register(ctx::Context, node) = node
|
||||
|
||||
function graph(ctx::Context, model, args...)
|
||||
register(ctx, graph(model, args...))
|
||||
end
|
||||
|
||||
function interp(ctx, model, args...)
|
||||
g = Flux.graph(model)
|
||||
g == nothing && return graph(ctx, model, args...)
|
||||
|
Loading…
Reference in New Issue
Block a user