split out runseq

This commit is contained in:
Mike J Innes 2017-04-19 17:33:55 +01:00
parent edfb0211e6
commit 19cf3e2b62

View File

@ -146,10 +146,16 @@ struct SeqModel <: Model
steps::Int steps::Int
end end
(m::SeqModel)(x::Tuple) = m.model(x) runseq(f, xs::Tuple...) = f(xs...)
runseq(f, xs::AbstractArray...) = stack(f(map(x -> (unstack(x,2)...,), xs)...), 2)
runseq(f, xs::BatchSeq...) = rebatchseq(runseq(f, rawbatch.(xs)...))
splitseq(xs) = unstack(rawbatch(xs), 2) function (m::SeqModel)(x)
joinseq(xs) = rebatchseq(stack(xs, 2)) runseq(x) do x
@assert length(x) == m.steps "Expected seq length $(m.steps), got $(size(x, 2))"
m.model(x)
end
end
(m::SeqModel)(x::AbstractArray) = stack(m((unstack(x, 2)...,)), 2) (m::SeqModel)(x::AbstractArray) = stack(m((unstack(x, 2)...,)), 2)
(m::SeqModel)(x::BatchSeq) = rebatchseq(m(rawbatch(x))) (m::SeqModel)(x::BatchSeq) = rebatchseq(m(rawbatch(x)))