diff --git a/src/backend/tensorflow/recurrent.jl b/src/backend/tensorflow/recurrent.jl index 23afa265..0f6d9b81 100644 --- a/src/backend/tensorflow/recurrent.jl +++ b/src/backend/tensorflow/recurrent.jl @@ -7,14 +7,19 @@ end cgroup(xs...) = Flow.group(map(constant, xs)...) -function tf(model::Flux.Unrolled) +function makesession(model::Flux.Unrolled) sess = Session(Graph()) input = placeholder(Float32) - instates = [placeholder(Float32) for _ in model.state] inputs = TensorFlow.unpack(input, num = model.steps, axis = 1) + instates = [placeholder(Float32) for _ in model.state] params, (outstates, outputs) = tograph(model.graph, cgroup(instates...), cgroup(inputs...)) output = TensorFlow.pack(outputs, axis = 1) run(sess, initialize_all_variables()) + sess, params, (instates, input), (outstates, output) +end + +function tf(model::Flux.Unrolled) + sess, params, (instates, input), (outstates, output) = makesession(model) SeqModel( Model(model, sess, params, [instates..., input], [outstates..., output],