batch support
This commit is contained in:
parent
b8f148b012
commit
d592f4e327
@ -150,7 +150,15 @@ function hDesc(h::CuArray)
|
||||
TensorDesc(eltype(h), (size(h, 1), size(h, 2), 1)), h
|
||||
end
|
||||
|
||||
function forward(rnn::RNNDesc{T}, x::CuArray{T}, h::CuArray{T}, c = nothing; train = false) where T
|
||||
# TODO: can we just manipulate strides here?
|
||||
# TODO: should use repmat, but this isn't implemented.
|
||||
# Unbatched case: `x` is a single vector and the state `h` is a vector as well,
# so no expansion is needed and `h` is handed back as-is (same object, no copy).
function hBatch(x::AbstractVector, h::CuVector)
  return h
end
|
||||
# Batched input with a vector state: broadcast-multiply `h` against a
# `cuones(1, size(x, 2))` array so the result has one column per column of `x`.
# NOTE(review): presumably `cuones(1, n)` is a 1×n array of ones, making every
# output column a copy of `h` — confirm against the `cuones` definition.
function hBatch(x::AbstractMatrix, h::CuVector)
  ncols = size(x, 2)
  return h .* cuones(1, ncols)
end
|
||||
# Batched input with a matrix state. When `h` has exactly one column it is
# broadcast across `size(x, 2)` columns; otherwise it is multiplied by a
# `cuones(1, 1)` array, which leaves its shape unchanged (the result is still a
# freshly broadcast array rather than `h` itself).
function hBatch(x::AbstractMatrix, h::CuMatrix)
  reps = size(h, 2) == 1 ? size(x, 2) : 1
  return h .* cuones(1, reps)
end
|
||||
|
||||
function forward(rnn::RNNDesc{T}, x::CuArray{T}, h_::CuArray{T}, c_ = nothing; train = false) where T
|
||||
h = hBatch(x, h_)
|
||||
c = c_ == nothing ? nothing : hBatch(x, c_)
|
||||
@assert size(x, 1) == rnn.input
|
||||
@assert size(h, 1) == rnn.hidden
|
||||
@assert size(x, 2) == size(h, 2)
|
||||
@ -273,7 +281,7 @@ function desc(rnn)
|
||||
return d
|
||||
end
|
||||
|
||||
import Flux.Tracker: data, isleaf, istracked, track, back_, @back
|
||||
import Flux.Tracker: data, isleaf, istracked, track, back_, @back, unbroadcast
|
||||
|
||||
struct RNNCall{R}
|
||||
rnn::R
|
||||
@ -318,10 +326,11 @@ end
|
||||
function back_(m::RNNCall{<:Union{CuRNN,CuGRU}}, y_, Δ, x, h)
|
||||
y, ho = y_
|
||||
dy, dho = Δ
|
||||
dx, dh = backwardData(descs[m.rnn], y, dy, dho, data(h))
|
||||
h_ = hBatch(x, data(h))
|
||||
dx, dh = backwardData(descs[m.rnn], y, dy, dho, h_)
|
||||
@back(x, dx)
|
||||
@back(h, dh)
|
||||
(dWi, dWh), db = backwardWeights(descs[m.rnn], data(x), data(h), y)
|
||||
@back(h, unbroadcast(h, dh))
|
||||
(dWi, dWh), db = backwardWeights(descs[m.rnn], data(x), h_, y)
|
||||
# We don't have to make this assumption, it's just slightly more complex.
|
||||
@assert all(isleaf.((m.rnn.Wi, m.rnn.Wh, m.rnn.b)))
|
||||
istracked(m.rnn.Wi) && accum_transpose!(m.rnn.Wi.grad, dWi)
|
||||
@ -332,11 +341,13 @@ end
|
||||
function back_(m::RNNCall{<:CuLSTM}, y_, Δ, x, h, c)
|
||||
y, ho, co = y_
|
||||
dy, dho, dco = Δ
|
||||
dx, dh, dc = backwardData(descs[m.rnn], y, dy, dho, dco, data(h), data(c))
|
||||
(dWi, dWh), db = backwardWeights(descs[m.rnn], data(x), data(h), y)
|
||||
h_ = hBatch(x, data(h))
|
||||
c_ = hBatch(x, data(c))
|
||||
dx, dh, dc = backwardData(descs[m.rnn], y, dy, dho, dco, h_, c_)
|
||||
@back(x, dx)
|
||||
@back(h, dh)
|
||||
@back(c, dc)
|
||||
@back(h, unbroadcast(h, dh))
|
||||
@back(c, unbroadcast(h, dc))
|
||||
(dWi, dWh), db = backwardWeights(descs[m.rnn], data(x), h_, y)
|
||||
@assert all(isleaf.((m.rnn.Wi, m.rnn.Wh, m.rnn.b)))
|
||||
istracked(m.rnn.Wi) && accum_transpose!(m.rnn.Wi.grad, dWi)
|
||||
istracked(m.rnn.Wh) && accum_transpose!(m.rnn.Wh.grad, dWh)
|
||||
|
Loading…
Reference in New Issue
Block a user