factor out node registration
This commit is contained in:
parent
ac1b686db5
commit
de72d83f7c
@ -59,12 +59,17 @@ end
|
|||||||
|
|
||||||
interp(ctx, p::Constant) = node(p.value)
|
interp(ctx, p::Constant) = node(p.value)
|
||||||
|
|
||||||
function graph(ctx::Context, model, args...)
|
function register(ctx::Context, node::mx.SymbolicNode)
|
||||||
node = graph(model, args...)
|
ctx[:stacks][nodename(node)] = stack(ctx)
|
||||||
isa(node, mx.SymbolicNode) && (ctx[:stacks][nodename(node)] = stack(ctx))
|
|
||||||
return node
|
return node
|
||||||
end
|
end
|
||||||
|
|
||||||
|
register(ctx::Context, node) = node
|
||||||
|
|
||||||
|
function graph(ctx::Context, model, args...)
|
||||||
|
register(ctx, graph(model, args...))
|
||||||
|
end
|
||||||
|
|
||||||
function interp(ctx, model, args...)
|
function interp(ctx, model, args...)
|
||||||
g = Flux.graph(model)
|
g = Flux.graph(model)
|
||||||
g == nothing && return graph(ctx, model, args...)
|
g == nothing && return graph(ctx, model, args...)
|
||||||
|
Loading…
Reference in New Issue
Block a user