split out runseq
This commit is contained in:
parent
edfb0211e6
commit
19cf3e2b62
12
src/model.jl
12
src/model.jl
@ -146,10 +146,16 @@ struct SeqModel <: Model
|
||||
steps::Int
|
||||
end
|
||||
|
||||
(m::SeqModel)(x::Tuple) = m.model(x)
|
||||
runseq(f, xs::Tuple...) = f(xs...)
|
||||
runseq(f, xs::AbstractArray...) = stack(f(map(x -> (unstack(x,2)...,), xs)...), 2)
|
||||
runseq(f, xs::BatchSeq...) = rebatchseq(runseq(f, rawbatch.(xs)...))
|
||||
|
||||
splitseq(xs) = unstack(rawbatch(xs), 2)
|
||||
joinseq(xs) = rebatchseq(stack(xs, 2))
|
||||
function (m::SeqModel)(x)
|
||||
runseq(x) do x
|
||||
@assert length(x) == m.steps "Expected seq length $(m.steps), got $(size(x, 2))"
|
||||
m.model(x)
|
||||
end
|
||||
end
|
||||
|
||||
(m::SeqModel)(x::AbstractArray) = stack(m((unstack(x, 2)...,)), 2)
|
||||
(m::SeqModel)(x::BatchSeq) = rebatchseq(m(rawbatch(x)))
|
||||
|
Loading…
Reference in New Issue
Block a user