fix tf train

This commit is contained in:
Mike J Innes 2017-04-19 14:48:10 +01:00
parent 42688f8aa8
commit f8a3b02c1d

View File

@ -54,7 +54,7 @@ function Flux.train!(m::Model, train, test=[]; epoch = 1, η = 0.1,
for e in 1:epoch for e in 1:epoch
info("Epoch $e\n") info("Epoch $e\n")
@progress for (x, y) in train @progress for (x, y) in train
y, cur_loss, _ = run(m.session, vcat(m.output, Loss, minimize_op), y, cur_loss, _ = run(m.session, [m.output, Loss, minimize_op],
Dict(m.inputs[1] => batchone(convertel(Float32, x)), Dict(m.inputs[1] => batchone(convertel(Float32, x)),
Y => batchone(convertel(Float32, y)))) Y => batchone(convertel(Float32, y))))
if i % 5000 == 0 if i % 5000 == 0