explicit hidden state batching

This commit is contained in:
Mike J Innes 2016-11-15 23:44:11 +00:00
parent 3c068744d2
commit 2d90d04789
2 changed files with 10 additions and 4 deletions

View File

@ -26,12 +26,16 @@ function batch(xs)
Batch{T,B}(xs)
end
function (m::Model)(args::Batch...)
function runmodel(m, args...)
@assert length(args) == length(m.inputs)
output = run(m.session, m.output, Dict(zip(m.inputs, args)))
ismultioutput(m) ? (batch.(output)...,) : batch(output)
end
function (m::Model)(args::Batch...)
runmodel(m, args...)
end
function (m::Model)(args...)
output = m(map(batchone, args)...)
ismultioutput(m) ? map(first, output) : first(output)

View File

@ -24,7 +24,7 @@ function tf(model::Flux.Unrolled)
Model(model, sess, params,
[instates..., input], [outstates..., output],
[placeholder(Float32)]),
batchone.(model.state))
model.state)
end
function batchseq(xs)
@ -35,11 +35,13 @@ function batchseq(xs)
Batch{Seq{T,S},B}(xs)
end
TensorFlow.get_tensors(x::Tuple) = TensorFlow.get_tensors(collect(x))
function (m::SeqModel)(x::BatchSeq)
if isempty(m.state) || length(first(m.state)) length(x)
m.state = map(batchone, m.m.model.states)
m.state = m.m.model.state
end
output = m.m(m.state..., x)
output = runmodel(m.m, m.state..., x)
m.state, output = output[1:end-1], output[end]
return batchseq(rawbatch(output))
end