update states references
This commit is contained in:
parent
4d45ee1bb9
commit
a1b1d87767
@ -10,7 +10,7 @@ cgroup(xs...) = Flow.group(map(constant, xs)...)
|
||||
function tf(model::Flux.Unrolled)
|
||||
sess = Session(Graph())
|
||||
input = placeholder(Float32)
|
||||
instates = [placeholder(Float32) for _ in model.states]
|
||||
instates = [placeholder(Float32) for _ in model.state]
|
||||
inputs = TensorFlow.unpack(input, num = model.steps, axis = 1)
|
||||
params, (outstates, outputs) = tograph(model.graph, cgroup(instates...), cgroup(inputs...))
|
||||
output = TensorFlow.pack(outputs, axis = 1)
|
||||
@ -19,7 +19,7 @@ function tf(model::Flux.Unrolled)
|
||||
Model(model, sess, params,
|
||||
[instates..., input], [outstates..., output],
|
||||
[gradients(output, input)]),
|
||||
batchone.(model.states))
|
||||
batchone.(model.state))
|
||||
end
|
||||
|
||||
function batchseq(xs)
|
||||
|
Loading…
Reference in New Issue
Block a user