diff --git a/src/backend/tensorflow/graph.jl b/src/backend/tensorflow/graph.jl
index f2d1b13b..0822ba9f 100644
--- a/src/backend/tensorflow/graph.jl
+++ b/src/backend/tensorflow/graph.jl
@@ -68,3 +68,11 @@ end
 TensorFlow.Tensor(m::Flux.Model, args...) = graph(Dict(), m, args...)
 
 RawTensor(data::Union{Batch,Seq}) = RawTensor(rawbatch(data))
+
+function makesession(model, n)
+  sess = Session(Graph())
+  inputs = [placeholder(Float32) for _ = 1:n]
+  params, output = tograph(model, inputs...)
+  run(sess, initialize_all_variables())
+  sess, params, inputs, output
+end
diff --git a/src/backend/tensorflow/model.jl b/src/backend/tensorflow/model.jl
index eb7d1b69..ca52bf15 100644
--- a/src/backend/tensorflow/model.jl
+++ b/src/backend/tensorflow/model.jl
@@ -10,13 +10,10 @@ end
 ismultioutput(m::Model) = !isa(m.output, Tensor)
 
 function tf(model)
-  sess = Session(Graph())
-  input = placeholder(Float32)
-  params, output = tograph(model, input)
-  run(sess, initialize_all_variables())
+  sess, params, input, output = makesession(model, 1)
   Model(model, sess, params,
-        [input], output,
-        [gradients(output, input)])
+        input, output,
+        gradients(output, input))
 end
 
 batchone(x) = Batch((x,))
@@ -57,12 +54,12 @@ function Flux.train!(m::Model, train, test=[]; epoch = 1, η = 0.1,
                      opt = TensorFlow.train.GradientDescentOptimizer(η))
   i = 0
   Y = placeholder(Float32)
-  Loss = loss(m.outputs[1], Y)
+  Loss = loss(m.output, Y)
   minimize_op = TensorFlow.train.minimize(opt, Loss)
   for e in 1:epoch
     info("Epoch $e\n")
     @progress for (x, y) in train
-      y, cur_loss, _ = run(m.session, vcat(m.outputs[1], Loss, minimize_op),
+      y, cur_loss, _ = run(m.session, vcat(m.output, Loss, minimize_op),
         Dict(m.inputs[1]=>batchone(x), Y=>batchone(y)))
       if i % 5000 == 0
         @show y
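
A minimal usage sketch of the extracted helper (not part of the patch), assuming
only the TensorFlow.jl calls already used above; the model `m` and the feed
values are hypothetical stand-ins:

  # Build a session with two input placeholders instead of tf's single one.
  sess, params, inputs, output = makesession(m, 2)
  # Evaluate the graph, feeding each placeholder the way train! feeds
  # m.inputs[1].
  result = run(sess, output, Dict(inputs[1] => rand(Float32, 10),
                                  inputs[2] => rand(Float32, 10)))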