batch support

This commit is contained in:
Mike J Innes 2018-02-08 01:06:08 +00:00
parent b8f148b012
commit d592f4e327

View File

@ -150,7 +150,15 @@ function hDesc(h::CuArray)
TensorDesc(eltype(h), (size(h, 1), size(h, 2), 1)), h
end
function forward(rnn::RNNDesc{T}, x::CuArray{T}, h::CuArray{T}, c = nothing; train = false) where T
# TODO: can we just manipulate strides here?
# TODO: should use repmat, but this isn't implemented.
hBatch(x::AbstractVector, h::CuVector) = h
hBatch(x::AbstractMatrix, h::CuVector) = h .* cuones(1, size(x, 2))
hBatch(x::AbstractMatrix, h::CuMatrix) = h .* cuones(1, size(h,2) == 1 ? size(x,2) : 1)
function forward(rnn::RNNDesc{T}, x::CuArray{T}, h_::CuArray{T}, c_ = nothing; train = false) where T
h = hBatch(x, h_)
c = c_ == nothing ? nothing : hBatch(x, c_)
@assert size(x, 1) == rnn.input
@assert size(h, 1) == rnn.hidden
@assert size(x, 2) == size(h, 2)
@ -273,7 +281,7 @@ function desc(rnn)
return d
end
import Flux.Tracker: data, isleaf, istracked, track, back_, @back
import Flux.Tracker: data, isleaf, istracked, track, back_, @back, unbroadcast
struct RNNCall{R}
rnn::R
@ -318,10 +326,11 @@ end
function back_(m::RNNCall{<:Union{CuRNN,CuGRU}}, y_, Δ, x, h)
y, ho = y_
dy, dho = Δ
dx, dh = backwardData(descs[m.rnn], y, dy, dho, data(h))
h_ = hBatch(x, data(h))
dx, dh = backwardData(descs[m.rnn], y, dy, dho, h_)
@back(x, dx)
@back(h, dh)
(dWi, dWh), db = backwardWeights(descs[m.rnn], data(x), data(h), y)
@back(h, unbroadcast(h, dh))
(dWi, dWh), db = backwardWeights(descs[m.rnn], data(x), h_, y)
# We don't have to make this assumption, it's just slightly more complex.
@assert all(isleaf.((m.rnn.Wi, m.rnn.Wh, m.rnn.b)))
istracked(m.rnn.Wi) && accum_transpose!(m.rnn.Wi.grad, dWi)
@ -332,11 +341,13 @@ end
function back_(m::RNNCall{<:CuLSTM}, y_, Δ, x, h, c)
y, ho, co = y_
dy, dho, dco = Δ
dx, dh, dc = backwardData(descs[m.rnn], y, dy, dho, dco, data(h), data(c))
(dWi, dWh), db = backwardWeights(descs[m.rnn], data(x), data(h), y)
h_ = hBatch(x, data(h))
c_ = hBatch(x, data(c))
dx, dh, dc = backwardData(descs[m.rnn], y, dy, dho, dco, h_, c_)
@back(x, dx)
@back(h, dh)
@back(c, dc)
@back(h, unbroadcast(h, dh))
@back(c, unbroadcast(h, dc))
(dWi, dWh), db = backwardWeights(descs[m.rnn], data(x), h_, y)
@assert all(isleaf.((m.rnn.Wi, m.rnn.Wh, m.rnn.b)))
istracked(m.rnn.Wi) && accum_transpose!(m.rnn.Wi.grad, dWi)
istracked(m.rnn.Wh) && accum_transpose!(m.rnn.Wh.grad, dWh)