fix model call

This commit is contained in:
Mike J Innes 2016-12-15 18:35:11 +00:00
parent 1b22d55401
commit c6fb9c1f0c
1 changed files with 1 additions and 1 deletions

View File

@ -47,7 +47,7 @@ TensorFlow.get_tensors(x::Tuple) = TensorFlow.get_tensors(collect(x))
function (m::SeqModel)(x::BatchSeq)
m.m.model.stateful || return batchseq(runmodel(m.m, x)[end])
if isempty(m.state) || length(first(m.state)) length(x)
m.state = m.m.model.state
m.state = batchone.(m.m.model.state)
end
output = runmodel(m.m, m.state..., x)
m.state, output = output[1:end-1], output[end]