diff --git a/src/backend/tensorflow/graph.jl b/src/backend/tensorflow/graph.jl index 05bc21b7..bd618223 100644 --- a/src/backend/tensorflow/graph.jl +++ b/src/backend/tensorflow/graph.jl @@ -1,7 +1,7 @@ using Base: @get! -using DataFlow: Constant, constant, Context, interpret, Split, interptuple, - interpv, interplambda, interpconst, interpline, stack -using Flux: interpmap +using DataFlow: Constant, constant, Context, interpret, Split, + interpv, ituple, ilambda, iconst, iline, stack, mux +using Flux: imap using TensorFlow: RawTensor # TODO: implement Julia's type promotion rules @@ -61,7 +61,7 @@ function interp(ctx, model, args...) end function tograph(model, args...) - ctx = Context(interpline(interplambda(interptuple(interpmap(interp)))), + ctx = Context(mux(iline, ilambda, ituple, imap, interp), params = ObjectIdDict(), stacks = Dict()) out = interp(ctx, model, map(constant, args)...) return ctx[:params], ctx[:stacks], out diff --git a/src/compiler/interp.jl b/src/compiler/interp.jl index f4d5737d..f271b3c2 100644 --- a/src/compiler/interp.jl +++ b/src/compiler/interp.jl @@ -1,4 +1,4 @@ -using DataFlow: interpret, interpv, interptuple, interplambda, interpconst, Context +using DataFlow: mux, interpret, interpv, ituple, ilambda, iconst, Context function astuple(xs::Vertex) isconstant(xs) && isa(value(xs).value, Tuple) ? value(xs).value : @@ -15,17 +15,16 @@ function astuples(xs) all(x->!(x==nothing), xs) ? xs : nothing end -function interpmap(cb) - function interp(ctx, ::typeof(map), f, xs...) - f, xs = interpv(ctx, (f, xs)) - xs′ = astuples(xs) - xs′ ≠ nothing ? - group(map(f, xs′...)...) : - cb(ctx, map, constant(f), xs...) - end - interp(args...) = cb(args...) +function imap(cb, ctx, ::typeof(map), f, xs...) + f, xs = interpv(ctx, (f, xs)) + xs′ = astuples(xs) + xs′ ≠ nothing ? + group(map(f, xs′...)...) : + cb(ctx, map, constant(f), xs...) end +imap(f, args...) = f(args...) + function interp(ctx, model, xs...) g = graph(model) g == nothing && return vertex(model, map(constant, interpv(ctx, xs))...) @@ -33,4 +32,4 @@ function interp(ctx, model, xs...) end expand(graph, xs...) = - interp(Context(interplambda(interpmap(interpconst(interptuple(interp))))), graph, xs...) + interp(Context(mux(ilambda, imap, iconst, ituple, interp)), graph, xs...)