use iargs
This commit is contained in:
parent
ce04a9f3c1
commit
a73b53e05e
@ -8,7 +8,7 @@ end
|
||||
|
||||
using Base: @get!
|
||||
using DataFlow: Constant, constant, Context, interpret, Split,
|
||||
interpv, ituple, ilambda, iconst, iline, stack, mux
|
||||
interpv, ituple, ilambda, iconst, iline, iargs, stack, mux
|
||||
using Flux: imap
|
||||
|
||||
# TODO: implement Julia's type promotion rules
|
||||
@ -61,7 +61,7 @@ end
|
||||
interp(ctx, p::Constant) = node(p.value)
|
||||
|
||||
function graph(ctx::Context, model, args...)
|
||||
node = graph(model, interpv(ctx, args)...)
|
||||
node = graph(model, args...)
|
||||
isa(node, mx.SymbolicNode) && (ctx[:stacks][nodename(node)] = stack(ctx))
|
||||
return node
|
||||
end
|
||||
@ -70,11 +70,11 @@ function interp(ctx, model, args...)
|
||||
g = Flux.graph(model)
|
||||
g == nothing && return graph(ctx, model, args...)
|
||||
DataFlow.iscyclic(g) && error("This model has a cycle; try unrolling it first.")
|
||||
interpret(ctx, g, interpv(ctx, args)...)
|
||||
interpret(ctx, g, args...)
|
||||
end
|
||||
|
||||
function tograph(model, args...)
|
||||
ctx = Context(mux(iline, ilambda, ituple, imap, interp),
|
||||
ctx = Context(mux(iline, ilambda, imap, iargs, interp),
|
||||
params = Dict(), stacks = Dict())
|
||||
out = interp(ctx, model, map(constant, args)...)
|
||||
return ctx[:params], ctx[:stacks], out
|
||||
|
Loading…
Reference in New Issue
Block a user