From b580c2e4a7ed90f10437bb59adc78af1037272be Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Tue, 20 Dec 2016 17:33:14 +0000 Subject: [PATCH] style improvement --- src/backend/tensorflow/recurrent.jl | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/src/backend/tensorflow/recurrent.jl b/src/backend/tensorflow/recurrent.jl index 18121038..5abee520 100644 --- a/src/backend/tensorflow/recurrent.jl +++ b/src/backend/tensorflow/recurrent.jl @@ -11,16 +11,17 @@ function makesession(model::Flux.Unrolled) sess = Session(Graph()) input = placeholder(Float32) inputs = TensorFlow.unpack(input, num = model.steps, axis = 1) - params, stacks, outputs, instates, outstates = [], [], [], [], [] - if model.stateful - instates = [placeholder(Float32) for _ in model.state] - params, stacks, (outstates, outputs) = tograph(model, cgroup(instates...), cgroup(inputs...)) - else - params, stacks, outputs = tograph(model, cgroup(inputs...)) + let params, stacks, outputs, instates, outstates + if model.stateful + instates = [placeholder(Float32) for _ in model.state] + params, stacks, (outstates, outputs) = tograph(model, cgroup(instates...), cgroup(inputs...)) + else + params, stacks, outputs = tograph(model, cgroup(inputs...)) + end + output = TensorFlow.pack(outputs, axis = 1) + run(sess, initialize_all_variables()) + sess, params, stacks, (instates, input), (outstates, output) end - output = TensorFlow.pack(outputs, axis = 1) - run(sess, initialize_all_variables()) - sess, params, stacks, (instates, input), (outstates, output) end function tf(model::Flux.Unrolled)