split out makesession logic
This commit is contained in:
parent
81d9743836
commit
e433ffce8f
@ -68,3 +68,11 @@ end
|
|||||||
TensorFlow.Tensor(m::Flux.Model, args...) = graph(Dict(), m, args...)
|
TensorFlow.Tensor(m::Flux.Model, args...) = graph(Dict(), m, args...)
|
||||||
|
|
||||||
RawTensor(data::Union{Batch,Seq}) = RawTensor(rawbatch(data))
|
RawTensor(data::Union{Batch,Seq}) = RawTensor(rawbatch(data))
|
||||||
|
|
||||||
|
function makesession(model, n)
|
||||||
|
sess = Session(Graph())
|
||||||
|
inputs = [placeholder(Float32) for _ = 1:n]
|
||||||
|
params, output = tograph(model, inputs...)
|
||||||
|
run(sess, initialize_all_variables())
|
||||||
|
sess, params, inputs, output
|
||||||
|
end
|
||||||
|
@ -10,13 +10,10 @@ end
|
|||||||
ismultioutput(m::Model) = !isa(m.output, Tensor)
|
ismultioutput(m::Model) = !isa(m.output, Tensor)
|
||||||
|
|
||||||
function tf(model)
|
function tf(model)
|
||||||
sess = Session(Graph())
|
sess, params, input, output = makesession(model, 1)
|
||||||
input = placeholder(Float32)
|
|
||||||
params, output = tograph(model, input)
|
|
||||||
run(sess, initialize_all_variables())
|
|
||||||
Model(model, sess, params,
|
Model(model, sess, params,
|
||||||
[input], output,
|
input, output,
|
||||||
[gradients(output, input)])
|
gradients(output, input))
|
||||||
end
|
end
|
||||||
|
|
||||||
batchone(x) = Batch((x,))
|
batchone(x) = Batch((x,))
|
||||||
@ -57,12 +54,12 @@ function Flux.train!(m::Model, train, test=[]; epoch = 1, η = 0.1,
|
|||||||
opt = TensorFlow.train.GradientDescentOptimizer(η))
|
opt = TensorFlow.train.GradientDescentOptimizer(η))
|
||||||
i = 0
|
i = 0
|
||||||
Y = placeholder(Float32)
|
Y = placeholder(Float32)
|
||||||
Loss = loss(m.outputs[1], Y)
|
Loss = loss(m.output, Y)
|
||||||
minimize_op = TensorFlow.train.minimize(opt, Loss)
|
minimize_op = TensorFlow.train.minimize(opt, Loss)
|
||||||
for e in 1:epoch
|
for e in 1:epoch
|
||||||
info("Epoch $e\n")
|
info("Epoch $e\n")
|
||||||
@progress for (x, y) in train
|
@progress for (x, y) in train
|
||||||
y, cur_loss, _ = run(m.session, vcat(m.outputs[1], Loss, minimize_op),
|
y, cur_loss, _ = run(m.session, vcat(m.output, Loss, minimize_op),
|
||||||
Dict(m.inputs[1]=>batchone(x), Y=>batchone(y)))
|
Dict(m.inputs[1]=>batchone(x), Y=>batchone(y)))
|
||||||
if i % 5000 == 0
|
if i % 5000 == 0
|
||||||
@show y
|
@show y
|
||||||
|
Loading…
Reference in New Issue
Block a user