Merge remote-tracking branch 'upstream/master' into add-more-tf-ops-2

This commit is contained in:
Ali Hamdi 2017-06-10 11:43:27 +02:00
commit 5143410313
2 changed files with 22 additions and 11 deletions

View File

@ -29,19 +29,30 @@ struct SeqModel
steps::Int
end
runseq(f, xs::Tuple...) = f(xs...)
runseq(f, xs::AbstractArray...) = stack(f(map(x -> (unstack(x,2)...,), xs)...), 2)
runseq(f, xs::Batch{<:Seq}...) = convert(Batch{Seq}, runseq(f, rawbatch.(xs)...))
runseq(f, xs) = runseq(f, (xs...,))
seqtuple(x, n) = x
seqtuple(xs::Tuple, n) = seqtuple.(xs, n)
function (m::SeqModel)(x)
runseq(x) do x
@assert length(x) == m.steps "Expected seq length $(m.steps), got $(size(x, 2))"
m.model(x)
end
seqtuple(xs::AbstractArray, n) =
ndims(xs) < 3 ? xs :
n 0 && size(xs, 2) n ? error("Expecting sequence length $n, got $(size(xs, 2))") :
(unstack(xs, 2)...)
seqtuple(xs::Batch{<:Seq}, n) = seqtuple(rawbatch(xs), n)
reseq(x) = x
reseq(x::Tuple{}) = ()
reseq(xs::Tuple) = all(isa.(xs, AbstractArray) .& (ndims.(xs) .≥ 2)) ? stack(xs, 2) : reseq.(xs)
function (m::SeqModel)(xs...)
xs = seqtuple(xs, m.steps)
reseq(m.model(xs...))
end
back!(m::SeqModel, Δ, x) = (runseq((Δ, x) -> back!(m.model, Δ, x)[1], Δ, x),)
function back!(m::SeqModel, args...)
args = seqtuple(args, 0)
# TODO: reseq
back!(m.model, args...)
end
update!(m::SeqModel, η) = update!(m.model, η)

View File

@ -30,7 +30,7 @@ function train!(m, train; cb = [],
@progress for e in 1:epoch
info("Epoch $e")
@cb for (x, y) in train
x, y = tobatch.((x, y))
x, y = mapt(tobatch, (x, y))
= m(x)
any(isnan, ) && error("NaN")
Δ = back!(loss, 1, , y)