diff --git a/src/backend/tensorflow/graph.jl b/src/backend/tensorflow/graph.jl index c320c63a..5482ea64 100644 --- a/src/backend/tensorflow/graph.jl +++ b/src/backend/tensorflow/graph.jl @@ -1,5 +1,5 @@ using Base: @get! -using DataFlow: Constant, constant, Split +using DataFlow: constant, Split using DataFlow.Interpreter using DataFlow.Interpreter: stack using TensorFlow: RawTensor, TFException @@ -53,33 +53,34 @@ graph(p::MaxPool, x) = graph(op::Op, xs...) = op.f(xs...) function graph(ctx::Context, model, args...) - node = graph(model, interpv(ctx, args)...) + node = graph(model, args...) node isa Tensor && (ctx[:stacks][node.op.name] = stack(ctx)) return node end interp(ctx, c::Conv2D, x) = - nn.conv2d(interpv(ctx, x), interp(ctx, Constant(c.filter)), [1,c.stride...,1], "VALID") + nn.conv2d(x, interp(ctx, constant(c.filter)), [1,c.stride...,1], "VALID") -interp{T<:AArray}(ctx, p::Constant{Flux.Param{T}}) = - haskey(ctx[:params], p.value) ? - ctx[:params][p.value] : - (ctx[:params][p.value] = +param(ctx, p::Flux.Param{<:AArray}) = + haskey(ctx[:params], p) ? + ctx[:params][p] : + (ctx[:params][p] = ctx[:variables] ? - Variable(Float32.(p.value.x)) : + Variable(Float32.(p.x)) : placeholder(Float32)) -interp(ctx, p::Constant) = p.value +param(ctx, x) = x function interp(ctx, model, args...) + args = param.(ctx, args) g = Flux.graph(model) g == nothing && return graph(ctx, model, args...) DataFlow.iscyclic(g) && error("This model has a cycle; try unrolling it first.") - interpret(ctx, g, interpv(ctx, args)...) + interpret(ctx, g, args...) end function tograph(model, args...; variables = false) - ctx = Context(mux(iline, ilambda, interp), + ctx = Context(mux(iline, iconst, ilambda, iargs, ituple, interp), params = ObjectIdDict(), stacks = Dict(), variables = variables) out = interp(ctx, model, map(constant, args)...) return ctx[:params], ctx[:stacks], out