interpreter middleware refactor

This commit is contained in:
Mike J Innes 2016-12-26 13:42:12 +00:00
parent 39398680b7
commit 147a26d045
2 changed files with 14 additions and 15 deletions

View File

@ -1,7 +1,7 @@
using Base: @get!
using DataFlow: Constant, constant, Context, interpret, Split, interptuple,
interpv, interplambda, interpconst, interpline, stack
using Flux: interpmap
using DataFlow: Constant, constant, Context, interpret, Split,
interpv, ituple, ilambda, iconst, iline, stack, mux
using Flux: imap
using TensorFlow: RawTensor
# TODO: implement Julia's type promotion rules
@ -61,7 +61,7 @@ function interp(ctx, model, args...)
end
function tograph(model, args...)
ctx = Context(interpline(interplambda(interptuple(interpmap(interp)))),
ctx = Context(mux(iline, ilambda, ituple, imap, interp),
params = ObjectIdDict(), stacks = Dict())
out = interp(ctx, model, map(constant, args)...)
return ctx[:params], ctx[:stacks], out

View File

@ -1,4 +1,4 @@
using DataFlow: interpret, interpv, interptuple, interplambda, interpconst, Context
using DataFlow: mux, interpret, interpv, ituple, ilambda, iconst, Context
function astuple(xs::Vertex)
isconstant(xs) && isa(value(xs).value, Tuple) ? value(xs).value :
@ -15,17 +15,16 @@ function astuples(xs)
all(x->!(x==nothing), xs) ? xs : nothing
end
function interpmap(cb)
function interp(ctx, ::typeof(map), f, xs...)
f, xs = interpv(ctx, (f, xs))
xs = astuples(xs)
xs nothing ?
group(map(f, xs...)...) :
cb(ctx, map, constant(f), xs...)
end
interp(args...) = cb(args...)
function imap(cb, ctx, ::typeof(map), f, xs...)
f, xs = interpv(ctx, (f, xs))
xs = astuples(xs)
xs nothing ?
group(map(f, xs...)...) :
cb(ctx, map, constant(f), xs...)
end
imap(f, args...) = f(args...)
function interp(ctx, model, xs...)
g = graph(model)
g == nothing && return vertex(model, map(constant, interpv(ctx, xs))...)
@ -33,4 +32,4 @@ function interp(ctx, model, xs...)
end
expand(graph, xs...) =
interp(Context(interplambda(interpmap(interpconst(interptuple(interp))))), graph, xs...)
interp(Context(mux(ilambda, imap, iconst, ituple, interp)), graph, xs...)