Merge remote-tracking branch 'upstream/master' into add-more-tf-ops-2
This commit is contained in:
commit
5143410313
@ -29,19 +29,30 @@ struct SeqModel
|
||||
steps::Int
|
||||
end
|
||||
|
||||
runseq(f, xs::Tuple...) = f(xs...)
|
||||
runseq(f, xs::AbstractArray...) = stack(f(map(x -> (unstack(x,2)...,), xs)...), 2)
|
||||
runseq(f, xs::Batch{<:Seq}...) = convert(Batch{Seq}, runseq(f, rawbatch.(xs)...))
|
||||
runseq(f, xs) = runseq(f, (xs...,))
|
||||
seqtuple(x, n) = x
|
||||
seqtuple(xs::Tuple, n) = seqtuple.(xs, n)
|
||||
|
||||
function (m::SeqModel)(x)
|
||||
runseq(x) do x
|
||||
@assert length(x) == m.steps "Expected seq length $(m.steps), got $(size(x, 2))"
|
||||
m.model(x)
|
||||
end
|
||||
seqtuple(xs::AbstractArray, n) =
|
||||
ndims(xs) < 3 ? xs :
|
||||
n ≠ 0 && size(xs, 2) ≠ n ? error("Expecting sequence length $n, got $(size(xs, 2))") :
|
||||
(unstack(xs, 2)...)
|
||||
|
||||
seqtuple(xs::Batch{<:Seq}, n) = seqtuple(rawbatch(xs), n)
|
||||
|
||||
reseq(x) = x
|
||||
reseq(x::Tuple{}) = ()
|
||||
reseq(xs::Tuple) = all(isa.(xs, AbstractArray) .& (ndims.(xs) .≥ 2)) ? stack(xs, 2) : reseq.(xs)
|
||||
|
||||
function (m::SeqModel)(xs...)
|
||||
xs = seqtuple(xs, m.steps)
|
||||
reseq(m.model(xs...))
|
||||
end
|
||||
|
||||
back!(m::SeqModel, Δ, x) = (runseq((Δ, x) -> back!(m.model, Δ, x)[1], Δ, x),)
|
||||
function back!(m::SeqModel, args...)
|
||||
args = seqtuple(args, 0)
|
||||
# TODO: reseq
|
||||
back!(m.model, args...)
|
||||
end
|
||||
|
||||
update!(m::SeqModel, η) = update!(m.model, η)
|
||||
|
||||
|
@ -30,7 +30,7 @@ function train!(m, train; cb = [],
|
||||
@progress for e in 1:epoch
|
||||
info("Epoch $e")
|
||||
@cb for (x, y) in train
|
||||
x, y = tobatch.((x, y))
|
||||
x, y = mapt(tobatch, (x, y))
|
||||
ŷ = m(x)
|
||||
any(isnan, ŷ) && error("NaN")
|
||||
Δ = back!(loss, 1, ŷ, y)
|
||||
|
Loading…
Reference in New Issue
Block a user