interpreter middleware refactor
This commit is contained in:
parent
39398680b7
commit
147a26d045
@ -1,7 +1,7 @@
|
|||||||
using Base: @get!
|
using Base: @get!
|
||||||
using DataFlow: Constant, constant, Context, interpret, Split, interptuple,
|
using DataFlow: Constant, constant, Context, interpret, Split,
|
||||||
interpv, interplambda, interpconst, interpline, stack
|
interpv, ituple, ilambda, iconst, iline, stack, mux
|
||||||
using Flux: interpmap
|
using Flux: imap
|
||||||
using TensorFlow: RawTensor
|
using TensorFlow: RawTensor
|
||||||
|
|
||||||
# TODO: implement Julia's type promotion rules
|
# TODO: implement Julia's type promotion rules
|
||||||
@ -61,7 +61,7 @@ function interp(ctx, model, args...)
|
|||||||
end
|
end
|
||||||
|
|
||||||
function tograph(model, args...)
|
function tograph(model, args...)
|
||||||
ctx = Context(interpline(interplambda(interptuple(interpmap(interp)))),
|
ctx = Context(mux(iline, ilambda, ituple, imap, interp),
|
||||||
params = ObjectIdDict(), stacks = Dict())
|
params = ObjectIdDict(), stacks = Dict())
|
||||||
out = interp(ctx, model, map(constant, args)...)
|
out = interp(ctx, model, map(constant, args)...)
|
||||||
return ctx[:params], ctx[:stacks], out
|
return ctx[:params], ctx[:stacks], out
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
using DataFlow: interpret, interpv, interptuple, interplambda, interpconst, Context
|
using DataFlow: mux, interpret, interpv, ituple, ilambda, iconst, Context
|
||||||
|
|
||||||
function astuple(xs::Vertex)
|
function astuple(xs::Vertex)
|
||||||
isconstant(xs) && isa(value(xs).value, Tuple) ? value(xs).value :
|
isconstant(xs) && isa(value(xs).value, Tuple) ? value(xs).value :
|
||||||
@ -15,17 +15,16 @@ function astuples(xs)
|
|||||||
all(x->!(x==nothing), xs) ? xs : nothing
|
all(x->!(x==nothing), xs) ? xs : nothing
|
||||||
end
|
end
|
||||||
|
|
||||||
function interpmap(cb)
|
function imap(cb, ctx, ::typeof(map), f, xs...)
|
||||||
function interp(ctx, ::typeof(map), f, xs...)
|
f, xs = interpv(ctx, (f, xs))
|
||||||
f, xs = interpv(ctx, (f, xs))
|
xs′ = astuples(xs)
|
||||||
xs′ = astuples(xs)
|
xs′ ≠ nothing ?
|
||||||
xs′ ≠ nothing ?
|
group(map(f, xs′...)...) :
|
||||||
group(map(f, xs′...)...) :
|
cb(ctx, map, constant(f), xs...)
|
||||||
cb(ctx, map, constant(f), xs...)
|
|
||||||
end
|
|
||||||
interp(args...) = cb(args...)
|
|
||||||
end
|
end
|
||||||
|
|
||||||
|
imap(f, args...) = f(args...)
|
||||||
|
|
||||||
function interp(ctx, model, xs...)
|
function interp(ctx, model, xs...)
|
||||||
g = graph(model)
|
g = graph(model)
|
||||||
g == nothing && return vertex(model, map(constant, interpv(ctx, xs))...)
|
g == nothing && return vertex(model, map(constant, interpv(ctx, xs))...)
|
||||||
@ -33,4 +32,4 @@ function interp(ctx, model, xs...)
|
|||||||
end
|
end
|
||||||
|
|
||||||
expand(graph, xs...) =
|
expand(graph, xs...) =
|
||||||
interp(Context(interplambda(interpmap(interpconst(interptuple(interp))))), graph, xs...)
|
interp(Context(mux(ilambda, imap, iconst, ituple, interp)), graph, xs...)
|
||||||
|
Loading…
Reference in New Issue
Block a user