split out runseq
This commit is contained in:
parent
edfb0211e6
commit
19cf3e2b62
12
src/model.jl
12
src/model.jl
@ -146,10 +146,16 @@ struct SeqModel <: Model
|
|||||||
steps::Int
|
steps::Int
|
||||||
end
|
end
|
||||||
|
|
||||||
(m::SeqModel)(x::Tuple) = m.model(x)
|
runseq(f, xs::Tuple...) = f(xs...)
|
||||||
|
runseq(f, xs::AbstractArray...) = stack(f(map(x -> (unstack(x,2)...,), xs)...), 2)
|
||||||
|
runseq(f, xs::BatchSeq...) = rebatchseq(runseq(f, rawbatch.(xs)...))
|
||||||
|
|
||||||
splitseq(xs) = unstack(rawbatch(xs), 2)
|
function (m::SeqModel)(x)
|
||||||
joinseq(xs) = rebatchseq(stack(xs, 2))
|
runseq(x) do x
|
||||||
|
@assert length(x) == m.steps "Expected seq length $(m.steps), got $(size(x, 2))"
|
||||||
|
m.model(x)
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
(m::SeqModel)(x::AbstractArray) = stack(m((unstack(x, 2)...,)), 2)
|
(m::SeqModel)(x::AbstractArray) = stack(m((unstack(x, 2)...,)), 2)
|
||||||
(m::SeqModel)(x::BatchSeq) = rebatchseq(m(rawbatch(x)))
|
(m::SeqModel)(x::BatchSeq) = rebatchseq(m(rawbatch(x)))
|
||||||
|
Loading…
Reference in New Issue
Block a user