explicit hidden state batching

This commit is contained in:
Mike J Innes 2016-11-15 23:44:11 +00:00
parent 3c068744d2
commit 2d90d04789
2 changed files with 10 additions and 4 deletions

View File

@ -26,12 +26,16 @@ function batch(xs)
Batch{T,B}(xs) Batch{T,B}(xs)
end end
function (m::Model)(args::Batch...) function runmodel(m, args...)
@assert length(args) == length(m.inputs) @assert length(args) == length(m.inputs)
output = run(m.session, m.output, Dict(zip(m.inputs, args))) output = run(m.session, m.output, Dict(zip(m.inputs, args)))
ismultioutput(m) ? (batch.(output)...,) : batch(output) ismultioutput(m) ? (batch.(output)...,) : batch(output)
end end
function (m::Model)(args::Batch...)
runmodel(m, args...)
end
function (m::Model)(args...) function (m::Model)(args...)
output = m(map(batchone, args)...) output = m(map(batchone, args)...)
ismultioutput(m) ? map(first, output) : first(output) ismultioutput(m) ? map(first, output) : first(output)

View File

@ -24,7 +24,7 @@ function tf(model::Flux.Unrolled)
Model(model, sess, params, Model(model, sess, params,
[instates..., input], [outstates..., output], [instates..., input], [outstates..., output],
[placeholder(Float32)]), [placeholder(Float32)]),
batchone.(model.state)) model.state)
end end
function batchseq(xs) function batchseq(xs)
@ -35,11 +35,13 @@ function batchseq(xs)
Batch{Seq{T,S},B}(xs) Batch{Seq{T,S},B}(xs)
end end
TensorFlow.get_tensors(x::Tuple) = TensorFlow.get_tensors(collect(x))
function (m::SeqModel)(x::BatchSeq) function (m::SeqModel)(x::BatchSeq)
if isempty(m.state) || length(first(m.state)) length(x) if isempty(m.state) || length(first(m.state)) length(x)
m.state = map(batchone, m.m.model.states) m.state = m.m.model.state
end end
output = m.m(m.state..., x) output = runmodel(m.m, m.state..., x)
m.state, output = output[1:end-1], output[end] m.state, output = output[1:end-1], output[end]
return batchseq(rawbatch(output)) return batchseq(rawbatch(output))
end end