update states references

This commit is contained in:
Mike J Innes 2016-10-30 01:58:39 +01:00
parent 4d45ee1bb9
commit a1b1d87767

View File

@ -10,7 +10,7 @@ cgroup(xs...) = Flow.group(map(constant, xs)...)
function tf(model::Flux.Unrolled)
sess = Session(Graph())
input = placeholder(Float32)
instates = [placeholder(Float32) for _ in model.states]
instates = [placeholder(Float32) for _ in model.state]
inputs = TensorFlow.unpack(input, num = model.steps, axis = 1)
params, (outstates, outputs) = tograph(model.graph, cgroup(instates...), cgroup(inputs...))
output = TensorFlow.pack(outputs, axis = 1)
@ -19,7 +19,7 @@ function tf(model::Flux.Unrolled)
Model(model, sess, params,
[instates..., input], [outstates..., output],
[gradients(output, input)]),
batchone.(model.states))
batchone.(model.state))
end
function batchseq(xs)