diff --git a/src/compiler/loops.jl b/src/compiler/loops.jl
index 33f38cdd..ce4bd5ed 100644
--- a/src/compiler/loops.jl
+++ b/src/compiler/loops.jl
@@ -29,19 +29,30 @@ struct SeqModel
   steps::Int
 end
 
-runseq(f, xs::Tuple...) = f(xs...)
-runseq(f, xs::AbstractArray...) = stack(f(map(x -> (unstack(x,2)...,), xs)...), 2)
-runseq(f, xs::Batch{<:Seq}...) = convert(Batch{Seq}, runseq(f, rawbatch.(xs)...))
-runseq(f, xs) = runseq(f, (xs...,))
+seqtuple(x, n) = x
+seqtuple(xs::Tuple, n) = seqtuple.(xs, n)
 
-function (m::SeqModel)(x)
-  runseq(x) do x
-    @assert length(x) == m.steps "Expected seq length $(m.steps), got $(size(x, 2))"
-    m.model(x)
-  end
+seqtuple(xs::AbstractArray, n) =
+  ndims(xs) < 3 ? xs :
+  n ≠ 0 && size(xs, 2) ≠ n ? error("Expecting sequence length $n, got $(size(xs, 2))") :
+  (unstack(xs, 2)...)
+
+seqtuple(xs::Batch{<:Seq}, n) = seqtuple(rawbatch(xs), n)
+
+reseq(x) = x
+reseq(x::Tuple{}) = ()
+reseq(xs::Tuple) = all(isa.(xs, AbstractArray) .& (ndims.(xs) .≥ 2)) ? stack(xs, 2) : reseq.(xs)
+
+function (m::SeqModel)(xs...)
+  xs = seqtuple(xs, m.steps)
+  reseq(m.model(xs...))
 end
 
-back!(m::SeqModel, Δ, x) = (runseq((Δ, x) -> back!(m.model, Δ, x)[1], Δ, x),)
+function back!(m::SeqModel, args...)
+  args = seqtuple(args, 0)
+  # TODO: reseq
+  back!(m.model, args...)
+end
 
 update!(m::SeqModel, η) = update!(m.model, η)
 
diff --git a/src/training.jl b/src/training.jl
index 70088444..62a34c46 100644
--- a/src/training.jl
+++ b/src/training.jl
@@ -30,7 +30,7 @@ function train!(m, train; cb = [],
   @progress for e in 1:epoch
     info("Epoch $e")
     @cb for (x, y) in train
-      x, y = tobatch.((x, y))
+      x, y = mapt(tobatch, (x, y))
       ŷ = m(x)
       any(isnan, ŷ) && error("NaN")
       Δ = back!(loss, 1, ŷ, y)