diff --git a/src/backend/tensorflow/graph.jl b/src/backend/tensorflow/graph.jl index ef9fdaa7..8d8a7f9a 100644 --- a/src/backend/tensorflow/graph.jl +++ b/src/backend/tensorflow/graph.jl @@ -52,7 +52,10 @@ interp(ctx, c::Conv2D, x) = interp{T<:AArray}(ctx, p::Constant{Flux.Param{T}}) = haskey(ctx[:params], p.value) ? ctx[:params][p.value] : - (ctx[:params][p.value] = Variable(convertel(Float32, p.value.x))) + (ctx[:params][p.value] = + ctx[:variables] ? + Variable(Float32.(p.value.x)) : + placeholder(Float32)) interp(ctx, p::Constant) = p.value @@ -63,14 +66,15 @@ function interp(ctx, model, args...) interpret(ctx, g, interpv(ctx, args)...) end -function tograph(model, args...) +function tograph(model, args...; variables = false) ctx = Context(mux(iline, ilambda, imap, interp), - params = ObjectIdDict(), stacks = Dict()) + params = ObjectIdDict(), stacks = Dict(), variables = variables) out = interp(ctx, model, map(constant, args)...) return ctx[:params], ctx[:stacks], out end -TensorFlow.Tensor(m::Flux.Model, args...) = tograph(m, args...)[3] +TensorFlow.Tensor(m::Flux.Model, args...) = + tograph(m, args...; variables = true)[3] RawTensor(data::Union{Batch,Seq}) = RawTensor(rawbatch(data)) diff --git a/src/backend/tensorflow/model.jl b/src/backend/tensorflow/model.jl index 62fe16a8..c8f61850 100644 --- a/src/backend/tensorflow/model.jl +++ b/src/backend/tensorflow/model.jl @@ -22,7 +22,9 @@ dictt(xs, ys) = Dict(zip(collectt(xs), collectt(ys))) function (m::Exec)(args...) shapecheckt(m.input, args) - retuple(run(m.session, m.output, dictt(m.input, args))) + idict = dictt(m.input, args) + pdict = Dict(t => p.x for (p, t) in m.params) + retuple(run(m.session, m.output, merge(idict, pdict))) end mutable struct Model @@ -34,7 +36,7 @@ end tf(model) = Model(model) function (m::Model)(args...) - args = mapt(x->convert.(Float32, x), args) + args = mapt(x->Float32.(x), args) isdefined(m, :graph) || (m.exec = makesession(m.model, args)) @tferr m.exec.stacks m.exec(args...) end diff --git a/test/backend/tensorflow.jl b/test/backend/tensorflow.jl index 35a97d7d..5d75d210 100644 --- a/test/backend/tensorflow.jl +++ b/test/backend/tensorflow.jl @@ -27,7 +27,7 @@ end Y = Tensor(d, X) run(sess, global_variables_initializer()) - @test run(sess, Y, Dict(X=>Float32.(xs))) ≈ d(xs) + @test run(sess, Y, Dict(X=>xs)) ≈ d(xs) end @testset "Stack Traces" begin