From 4de16171db042e35a0239103d353efdfbadb0bcc Mon Sep 17 00:00:00 2001
From: Mike J Innes
Date: Sat, 29 Oct 2016 00:10:27 +0100
Subject: [PATCH] basic sequence model training

---
 src/backend/tensorflow/recurrent.jl | 26 +++++++++++++++++++++++++-
 1 file changed, 25 insertions(+), 1 deletion(-)

diff --git a/src/backend/tensorflow/recurrent.jl b/src/backend/tensorflow/recurrent.jl
index b7c39939..eb3c6967 100644
--- a/src/backend/tensorflow/recurrent.jl
+++ b/src/backend/tensorflow/recurrent.jl
@@ -19,7 +19,7 @@ function tf(model::Flux.Unrolled)
     Model(model, sess, params,
           [instates..., input], [outstates..., output],
           [gradients(output, input)]),
-    [])
+    batchone.(model.states))
 end
 
 function batchseq(xs)
@@ -40,3 +40,27 @@ function (m::SeqModel)(x::BatchSeq)
 end
 
 (m::SeqModel)(x::Seq) = first(m(batchone(x)))
+
+# Train `m` on the `(x, y)` pairs yielded by `train`.
+# By default minimises a squared-error loss between the model's final
+# output and the target, using plain gradient descent with rate `η`.
+function Flux.train!(m::SeqModel, train; epoch = 1, η = 0.1,
+                     loss = (y, y′) -> reduce_sum((y - y′).^2)/2,
+                     opt = TensorFlow.train.GradientDescentOptimizer(η))
+  i = 0
+  Y = placeholder(Float32)
+  Loss = loss(m.m.output[end], Y)
+  minimize_op = TensorFlow.train.minimize(opt, Loss)
+  for e in 1:epoch
+    info("Epoch $e")
+    @progress for (x, y) in train
+      # Fetch prediction and loss alongside the update op in a single run.
+      # The recurrent state is fed from `m.state` (inputs 1:end-1); note it
+      # is not written back here, so state does not carry across batches.
+      pred, cur_loss, _ = run(m.m.session, [m.m.output[end], Loss, minimize_op],
+                              merge(Dict(m.m.inputs[end]=>batchone(x), Y=>batchone(y)),
+                                    Dict(zip(m.m.inputs[1:end-1], m.state))))
+      i % 5000 == 0 && @show pred cur_loss
+      i += 1
+    end
+  end
+end