From 19cf3e2b628ca7629a51c22c820bee2b2f14a68f Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Wed, 19 Apr 2017 17:33:55 +0100 Subject: [PATCH] split out runseq --- src/model.jl | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/src/model.jl b/src/model.jl index 4798f07d..61b4fdc0 100644 --- a/src/model.jl +++ b/src/model.jl @@ -146,10 +146,16 @@ struct SeqModel <: Model steps::Int end -(m::SeqModel)(x::Tuple) = m.model(x) +runseq(f, xs::Tuple...) = f(xs...) +runseq(f, xs::AbstractArray...) = stack(f(map(x -> (unstack(x,2)...,), xs)...), 2) +runseq(f, xs::BatchSeq...) = rebatchseq(runseq(f, rawbatch.(xs)...)) -splitseq(xs) = unstack(rawbatch(xs), 2) -joinseq(xs) = rebatchseq(stack(xs, 2)) +function (m::SeqModel)(x) + runseq(x) do x + @assert length(x) == m.steps "Expected seq length $(m.steps), got $(size(x, 2))" + m.model(x) + end +end (m::SeqModel)(x::AbstractArray) = stack(m((unstack(x, 2)...,)), 2) (m::SeqModel)(x::BatchSeq) = rebatchseq(m(rawbatch(x)))