From a1b1d8776759fc9b25aab7af1463bbe462c33d53 Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Sun, 30 Oct 2016 01:58:39 +0100 Subject: [PATCH] update states references --- src/backend/tensorflow/recurrent.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/backend/tensorflow/recurrent.jl b/src/backend/tensorflow/recurrent.jl index 61e3c23d..bc75e760 100644 --- a/src/backend/tensorflow/recurrent.jl +++ b/src/backend/tensorflow/recurrent.jl @@ -10,7 +10,7 @@ cgroup(xs...) = Flow.group(map(constant, xs)...) function tf(model::Flux.Unrolled) sess = Session(Graph()) input = placeholder(Float32) - instates = [placeholder(Float32) for _ in model.states] + instates = [placeholder(Float32) for _ in model.state] inputs = TensorFlow.unpack(input, num = model.steps, axis = 1) params, (outstates, outputs) = tograph(model.graph, cgroup(instates...), cgroup(inputs...)) output = TensorFlow.pack(outputs, axis = 1) @@ -19,7 +19,7 @@ function tf(model::Flux.Unrolled) Model(model, sess, params, [instates..., input], [outstates..., output], [gradients(output, input)]), - batchone.(model.states)) + batchone.(model.state)) end function batchseq(xs)