split out makesession

This commit is contained in:
Mike J Innes 2016-10-30 12:29:00 +00:00
parent d5e0804669
commit 3b70ea6a42

View File

@ -7,14 +7,19 @@ end
cgroup(xs...) = Flow.group(map(constant, xs)...)
function tf(model::Flux.Unrolled)
function makesession(model::Flux.Unrolled)
sess = Session(Graph())
input = placeholder(Float32)
instates = [placeholder(Float32) for _ in model.state]
inputs = TensorFlow.unpack(input, num = model.steps, axis = 1)
instates = [placeholder(Float32) for _ in model.state]
params, (outstates, outputs) = tograph(model.graph, cgroup(instates...), cgroup(inputs...))
output = TensorFlow.pack(outputs, axis = 1)
run(sess, initialize_all_variables())
sess, params, (instates, input), (outstates, output)
end
function tf(model::Flux.Unrolled)
sess, params, (instates, input), (outstates, output) = makesession(model)
SeqModel(
Model(model, sess, params,
[instates..., input], [outstates..., output],