update states references
This commit is contained in:
parent
4d45ee1bb9
commit
a1b1d87767
@ -10,7 +10,7 @@ cgroup(xs...) = Flow.group(map(constant, xs)...)
|
|||||||
function tf(model::Flux.Unrolled)
|
function tf(model::Flux.Unrolled)
|
||||||
sess = Session(Graph())
|
sess = Session(Graph())
|
||||||
input = placeholder(Float32)
|
input = placeholder(Float32)
|
||||||
instates = [placeholder(Float32) for _ in model.states]
|
instates = [placeholder(Float32) for _ in model.state]
|
||||||
inputs = TensorFlow.unpack(input, num = model.steps, axis = 1)
|
inputs = TensorFlow.unpack(input, num = model.steps, axis = 1)
|
||||||
params, (outstates, outputs) = tograph(model.graph, cgroup(instates...), cgroup(inputs...))
|
params, (outstates, outputs) = tograph(model.graph, cgroup(instates...), cgroup(inputs...))
|
||||||
output = TensorFlow.pack(outputs, axis = 1)
|
output = TensorFlow.pack(outputs, axis = 1)
|
||||||
@ -19,7 +19,7 @@ function tf(model::Flux.Unrolled)
|
|||||||
Model(model, sess, params,
|
Model(model, sess, params,
|
||||||
[instates..., input], [outstates..., output],
|
[instates..., input], [outstates..., output],
|
||||||
[gradients(output, input)]),
|
[gradients(output, input)]),
|
||||||
batchone.(model.states))
|
batchone.(model.state))
|
||||||
end
|
end
|
||||||
|
|
||||||
function batchseq(xs)
|
function batchseq(xs)
|
||||||
|
Loading…
Reference in New Issue
Block a user