interpreter middleware refactor

This commit is contained in:
Mike J Innes 2016-12-26 13:42:12 +00:00
parent 39398680b7
commit 147a26d045
2 changed files with 14 additions and 15 deletions

View File

@ -1,7 +1,7 @@
using Base: @get! using Base: @get!
using DataFlow: Constant, constant, Context, interpret, Split, interptuple, using DataFlow: Constant, constant, Context, interpret, Split,
interpv, interplambda, interpconst, interpline, stack interpv, ituple, ilambda, iconst, iline, stack, mux
using Flux: interpmap using Flux: imap
using TensorFlow: RawTensor using TensorFlow: RawTensor
# TODO: implement Julia's type promotion rules # TODO: implement Julia's type promotion rules
@ -61,7 +61,7 @@ function interp(ctx, model, args...)
end end
function tograph(model, args...) function tograph(model, args...)
ctx = Context(interpline(interplambda(interptuple(interpmap(interp)))), ctx = Context(mux(iline, ilambda, ituple, imap, interp),
params = ObjectIdDict(), stacks = Dict()) params = ObjectIdDict(), stacks = Dict())
out = interp(ctx, model, map(constant, args)...) out = interp(ctx, model, map(constant, args)...)
return ctx[:params], ctx[:stacks], out return ctx[:params], ctx[:stacks], out

View File

@ -1,4 +1,4 @@
using DataFlow: interpret, interpv, interptuple, interplambda, interpconst, Context using DataFlow: mux, interpret, interpv, ituple, ilambda, iconst, Context
function astuple(xs::Vertex) function astuple(xs::Vertex)
isconstant(xs) && isa(value(xs).value, Tuple) ? value(xs).value : isconstant(xs) && isa(value(xs).value, Tuple) ? value(xs).value :
@ -15,17 +15,16 @@ function astuples(xs)
all(x->!(x==nothing), xs) ? xs : nothing all(x->!(x==nothing), xs) ? xs : nothing
end end
function interpmap(cb) function imap(cb, ctx, ::typeof(map), f, xs...)
function interp(ctx, ::typeof(map), f, xs...) f, xs = interpv(ctx, (f, xs))
f, xs = interpv(ctx, (f, xs)) xs = astuples(xs)
xs = astuples(xs) xs nothing ?
xs nothing ? group(map(f, xs...)...) :
group(map(f, xs...)...) : cb(ctx, map, constant(f), xs...)
cb(ctx, map, constant(f), xs...)
end
interp(args...) = cb(args...)
end end
imap(f, args...) = f(args...)
function interp(ctx, model, xs...) function interp(ctx, model, xs...)
g = graph(model) g = graph(model)
g == nothing && return vertex(model, map(constant, interpv(ctx, xs))...) g == nothing && return vertex(model, map(constant, interpv(ctx, xs))...)
@ -33,4 +32,4 @@ function interp(ctx, model, xs...)
end end
expand(graph, xs...) = expand(graph, xs...) =
interp(Context(interplambda(interpmap(interpconst(interptuple(interp))))), graph, xs...) interp(Context(mux(ilambda, imap, iconst, ituple, interp)), graph, xs...)