fix tf train
This commit is contained in:
parent
42688f8aa8
commit
f8a3b02c1d
@ -54,7 +54,7 @@ function Flux.train!(m::Model, train, test=[]; epoch = 1, η = 0.1,
|
|||||||
for e in 1:epoch
|
for e in 1:epoch
|
||||||
info("Epoch $e\n")
|
info("Epoch $e\n")
|
||||||
@progress for (x, y) in train
|
@progress for (x, y) in train
|
||||||
y, cur_loss, _ = run(m.session, vcat(m.output, Loss, minimize_op),
|
y, cur_loss, _ = run(m.session, [m.output, Loss, minimize_op],
|
||||||
Dict(m.inputs[1] => batchone(convertel(Float32, x)),
|
Dict(m.inputs[1] => batchone(convertel(Float32, x)),
|
||||||
Y => batchone(convertel(Float32, y))))
|
Y => batchone(convertel(Float32, y))))
|
||||||
if i % 5000 == 0
|
if i % 5000 == 0
|
||||||
|
Loading…
Reference in New Issue
Block a user