split out makesession logic

This commit is contained in:
Mike J Innes 2016-10-30 12:10:44 +00:00
parent 81d9743836
commit e433ffce8f
2 changed files with 13 additions and 8 deletions

View File

@ -68,3 +68,11 @@ end
TensorFlow.Tensor(m::Flux.Model, args...) = graph(Dict(), m, args...)
RawTensor(data::Union{Batch,Seq}) = RawTensor(rawbatch(data))
function makesession(model, n)
sess = Session(Graph())
inputs = [placeholder(Float32) for _ = 1:n]
params, output = tograph(model, inputs...)
run(sess, initialize_all_variables())
sess, params, inputs, output
end

View File

@ -10,13 +10,10 @@ end
ismultioutput(m::Model) = !isa(m.output, Tensor)
function tf(model)
sess = Session(Graph())
input = placeholder(Float32)
params, output = tograph(model, input)
run(sess, initialize_all_variables())
sess, params, input, output = makesession(model, 1)
Model(model, sess, params,
[input], output,
[gradients(output, input)])
input, output,
gradients(output, input))
end
batchone(x) = Batch((x,))
@ -57,12 +54,12 @@ function Flux.train!(m::Model, train, test=[]; epoch = 1, η = 0.1,
opt = TensorFlow.train.GradientDescentOptimizer(η))
i = 0
Y = placeholder(Float32)
Loss = loss(m.outputs[1], Y)
Loss = loss(m.output, Y)
minimize_op = TensorFlow.train.minimize(opt, Loss)
for e in 1:epoch
info("Epoch $e\n")
@progress for (x, y) in train
y, cur_loss, _ = run(m.session, vcat(m.outputs[1], Loss, minimize_op),
y, cur_loss, _ = run(m.session, vcat(m.output, Loss, minimize_op),
Dict(m.inputs[1]=>batchone(x), Y=>batchone(y)))
if i % 5000 == 0
@show y