diff --git a/src/backend/tensorflow/model.jl b/src/backend/tensorflow/model.jl
index 2ec5ebe9..fe10a1af 100644
--- a/src/backend/tensorflow/model.jl
+++ b/src/backend/tensorflow/model.jl
@@ -1,39 +1,44 @@
-struct Model
-  model::Any
+using Flux: mapt
+
+# A session plus the graph inputs, output, parameter tensors and stacks from `tograph`.
+struct Exec
   session::Session
-  params::Dict{Flux.Param,Tensor}
-  stacks::Dict
-  inputs::Vector{Tensor}
+  input::Any
   output::Any
+  params::Dict{Flux.Param,Tensor}
+  stacks::Dict{Any,Any}
 end
 
 function makesession(model, inputs; session = Session(Graph()))
   params, stacks, output = tograph(model, inputs...)
   run(session, global_variables_initializer())
-  Model(model, session, params, stacks, inputs, output)
+  Exec(session, inputs, output, params, stacks)
 end
 
 function makesession(model, n::Integer; session = Session(Graph()))
   makesession(model, [placeholder(Float32) for _ = 1:n], session = session)
 end
 
-tf(model) = makesession(model, 1)
-
-function storeparams!(sess, params)
-  for (p, t) in params
-    p.x = run(sess, t)
-  end
+# Run the session on `m.output`, feeding `args` positionally to `m.input`.
+function (m::Exec)(args...)
+  @assert length(args) == length(m.input)
+  run(m.session, m.output, Dict(zip(m.input, args)))
 end
 
-storeparams!(m::Model) = storeparams!(m.session, m.params)
-
-function runmodel(m::Model, args...)
-  @assert length(args) == length(m.inputs)
-  run(m.session, m.output, Dict(zip(m.inputs, args)))
+# Wraps a model; the constructor leaves `exec` undefined until the first call.
+mutable struct Model
+  model::Any
+  exec::Exec
+  Model(model) = new(model)
 end
 
-function (m::Model)(x)
-  @tferr m.stacks runmodel(m, convert.(Float32, x))
+tf(model) = Model(model)
+
+function (m::Model)(args...)
+  args = mapt(x->convert.(Float32, x),args)
+  # Build the session only once: `exec` is the field `new(model)` leaves unset.
+  isdefined(m, :exec) || (m.exec = makesession(m.model, length(args)))
+  @tferr m.exec.stacks m.exec(args...)
 end
 
 for f in :[back!, update!].args
@@ -41,27 +46,3 @@ for f in :[back!, update!].args
     error($(string(f)) * " is not yet supported on TensorFlow models")
   end
 end
-
-import Juno: info
-
-function Flux.train!(m::Model, train, test=[]; epoch = 1, η = 0.1,
-                     loss = (y, y′) -> reduce_sum((y - y′).^2)/2,
-                     opt = TensorFlow.train.GradientDescentOptimizer(η))
-  i = 0
-  Y = placeholder(Float32)
-  Loss = loss(m.output, Y)
-  minimize_op = TensorFlow.train.minimize(opt, Loss)
-  for e in 1:epoch
-    info("Epoch $e\n")
-    @progress for (x, y) in train
-      y, cur_loss, _ = run(m.session, [m.output, Loss, minimize_op],
-                           Dict(m.inputs[1] => batchone(convertel(Float32, x)),
-                                Y => batchone(convertel(Float32, y))))
-      if i % 5000 == 0
-        @show y
-        @show accuracy(m, test)
-      end
-      i += 1
-    end
-  end
-end
diff --git a/test/backend/tensorflow.jl b/test/backend/tensorflow.jl
index 81a9ad98..54f5560b 100644
--- a/test/backend/tensorflow.jl
+++ b/test/backend/tensorflow.jl
@@ -3,12 +3,16 @@ Flux.loadtf()
 
 @testset "TensorFlow" begin
 
-xs = rand(1, 20)
+xs, ys = rand(1, 20), rand(1, 20)
 d = Affine(20, 10)
 
 dt = tf(d)
 @test d(xs) ≈ dt(xs)
 
+m = Multi(20, 15)
+mm = tf(m)
+@test all(isapprox.(mm(xs, ys), m(xs, ys)))
+
 @testset "Tensor interface" begin
   sess = TensorFlow.Session()
   X = placeholder(Float32)