diff --git a/src/model.jl b/src/model.jl index 4798f07d..61b4fdc0 100644 --- a/src/model.jl +++ b/src/model.jl @@ -146,10 +146,16 @@ struct SeqModel <: Model steps::Int end -(m::SeqModel)(x::Tuple) = m.model(x) +runseq(f, xs::Tuple...) = f(xs...) +runseq(f, xs::AbstractArray...) = stack(f(map(x -> (unstack(x,2)...,), xs)...), 2) +runseq(f, xs::BatchSeq...) = rebatchseq(runseq(f, rawbatch.(xs)...)) -splitseq(xs) = unstack(rawbatch(xs), 2) -joinseq(xs) = rebatchseq(stack(xs, 2)) +function (m::SeqModel)(x) + runseq(x) do x + @assert length(x) == m.steps "Expected seq length $(m.steps), got $(size(x, 2))" + m.model(x) + end +end (m::SeqModel)(x::AbstractArray) = stack(m((unstack(x, 2)...,)), 2) (m::SeqModel)(x::BatchSeq) = rebatchseq(m(rawbatch(x)))