explicit hidden state batching
This commit is contained in:
parent
3c068744d2
commit
2d90d04789
@ -26,12 +26,16 @@ function batch(xs)
|
||||
Batch{T,B}(xs)
|
||||
end
|
||||
|
||||
function (m::Model)(args::Batch...)
|
||||
function runmodel(m, args...)
|
||||
@assert length(args) == length(m.inputs)
|
||||
output = run(m.session, m.output, Dict(zip(m.inputs, args)))
|
||||
ismultioutput(m) ? (batch.(output)...,) : batch(output)
|
||||
end
|
||||
|
||||
function (m::Model)(args::Batch...)
|
||||
runmodel(m, args...)
|
||||
end
|
||||
|
||||
function (m::Model)(args...)
|
||||
output = m(map(batchone, args)...)
|
||||
ismultioutput(m) ? map(first, output) : first(output)
|
||||
|
@ -24,7 +24,7 @@ function tf(model::Flux.Unrolled)
|
||||
Model(model, sess, params,
|
||||
[instates..., input], [outstates..., output],
|
||||
[placeholder(Float32)]),
|
||||
batchone.(model.state))
|
||||
model.state)
|
||||
end
|
||||
|
||||
function batchseq(xs)
|
||||
@ -35,11 +35,13 @@ function batchseq(xs)
|
||||
Batch{Seq{T,S},B}(xs)
|
||||
end
|
||||
|
||||
TensorFlow.get_tensors(x::Tuple) = TensorFlow.get_tensors(collect(x))
|
||||
|
||||
function (m::SeqModel)(x::BatchSeq)
|
||||
if isempty(m.state) || length(first(m.state)) ≠ length(x)
|
||||
m.state = map(batchone, m.m.model.states)
|
||||
m.state = m.m.model.state
|
||||
end
|
||||
output = m.m(m.state..., x)
|
||||
output = runmodel(m.m, m.state..., x)
|
||||
m.state, output = output[1:end-1], output[end]
|
||||
return batchseq(rawbatch(output))
|
||||
end
|
||||
|
Loading…
Reference in New Issue
Block a user