diff --git a/src/cuda/cudnn.jl b/src/cuda/cudnn.jl
index 972163f0..356d7237 100644
--- a/src/cuda/cudnn.jl
+++ b/src/cuda/cudnn.jl
@@ -150,7 +150,15 @@ function hDesc(h::CuArray)
   TensorDesc(eltype(h), (size(h, 1), size(h, 2), 1)), h
 end
 
-function forward(rnn::RNNDesc{T}, x::CuArray{T}, h::CuArray{T}, c = nothing; train = false) where T
+# TODO: can we just manipulate strides here?
+# TODO: should use repmat, but this isn't implemented.
+hBatch(x::AbstractVector, h::CuVector) = h
+hBatch(x::AbstractMatrix, h::CuVector) = h .* cuones(1, size(x, 2))
+hBatch(x::AbstractMatrix, h::CuMatrix) = h .* cuones(1, size(h, 2) == 1 ? size(x, 2) : 1)
+
+function forward(rnn::RNNDesc{T}, x::CuArray{T}, h_::CuArray{T}, c_ = nothing; train = false) where T
+  h = hBatch(x, h_)
+  c = c_ == nothing ? nothing : hBatch(x, c_)
   @assert size(x, 1) == rnn.input
   @assert size(h, 1) == rnn.hidden
   @assert size(x, 2) == size(h, 2)
@@ -273,7 +281,7 @@ function desc(rnn)
   return d
 end
 
-import Flux.Tracker: data, isleaf, istracked, track, back_, @back
+import Flux.Tracker: data, isleaf, istracked, track, back_, @back, unbroadcast
 
 struct RNNCall{R}
   rnn::R
@@ -318,10 +326,11 @@ end
 function back_(m::RNNCall{<:Union{CuRNN,CuGRU}}, y_, Δ, x, h)
   y, ho = y_
   dy, dho = Δ
-  dx, dh = backwardData(descs[m.rnn], y, dy, dho, data(h))
+  h_ = hBatch(x, data(h))
+  dx, dh = backwardData(descs[m.rnn], y, dy, dho, h_)
   @back(x, dx)
-  @back(h, dh)
-  (dWi, dWh), db = backwardWeights(descs[m.rnn], data(x), data(h), y)
+  @back(h, unbroadcast(h, dh))
+  (dWi, dWh), db = backwardWeights(descs[m.rnn], data(x), h_, y)
   # We don't have to make this assumption, it's just slightly more complex.
   @assert all(isleaf.((m.rnn.Wi, m.rnn.Wh, m.rnn.b)))
   istracked(m.rnn.Wi) && accum_transpose!(m.rnn.Wi.grad, dWi)
@@ -332,11 +341,13 @@ end
 function back_(m::RNNCall{<:CuLSTM}, y_, Δ, x, h, c)
   y, ho, co = y_
   dy, dho, dco = Δ
-  dx, dh, dc = backwardData(descs[m.rnn], y, dy, dho, dco, data(h), data(c))
-  (dWi, dWh), db = backwardWeights(descs[m.rnn], data(x), data(h), y)
+  h_ = hBatch(x, data(h))
+  c_ = hBatch(x, data(c))
+  dx, dh, dc = backwardData(descs[m.rnn], y, dy, dho, dco, h_, c_)
   @back(x, dx)
-  @back(h, dh)
-  @back(c, dc)
+  @back(h, unbroadcast(h, dh))
+  @back(c, unbroadcast(c, dc))
+  (dWi, dWh), db = backwardWeights(descs[m.rnn], data(x), h_, y)
   @assert all(isleaf.((m.rnn.Wi, m.rnn.Wh, m.rnn.b)))
   istracked(m.rnn.Wi) && accum_transpose!(m.rnn.Wi.grad, dWi)
   istracked(m.rnn.Wh) && accum_transpose!(m.rnn.Wh.grad, dWh)
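
Note on the new `hBatch` helpers: CUDNN requires the hidden state to carry the
same batch dimension as the input (hence the `size(x, 2) == size(h, 2)`
assertion in `forward`), while Flux's recurrent cells often hold a hidden state
with batch size 1. `hBatch` broadcasts such a state up to the input's batch
size before it reaches CUDNN. A minimal CPU sketch of the same dispatch
pattern, using `ones` and plain arrays as illustrative stand-ins for `cuones`
and `CuArray` (not the code in the diff itself):

# Single sample: the hidden vector already matches, pass it through.
hBatch(x::AbstractVector, h::Vector) = h

# Batched input, unbatched hidden state: replicate h across the batch.
hBatch(x::AbstractMatrix, h::Vector) = h .* ones(1, size(x, 2))

# A hidden matrix with a singleton batch dimension is expanded to the input's
# batch size; one that already matches the batch is returned unchanged.
hBatch(x::AbstractMatrix, h::Matrix) =
  h .* ones(1, size(h, 2) == 1 ? size(x, 2) : 1)

x = rand(10, 4)   # 10 features, batch of 4
h = randn(5)      # 5 hidden units, no batch dimension
hBatch(x, h)      # 5×4: h repeated along the batch dimension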
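
Correspondingly, the backward passes must undo that broadcast: the gradients
`dh` and `dc` come back from CUDNN with the expanded batch shape, so they are
reduced to the shape of the original `h` (or `c`) with `unbroadcast` before
`@back` accumulates them. A rough sketch of what such a reduction does, in
current Julia syntax, assuming only the batch (second) dimension was broadcast;
the hypothetical `unbroadcast_batch` below is not the real
`Flux.Tracker.unbroadcast`, which handles the general case:

# Sum the gradient over the broadcast batch dimension so that it matches
# the shape of the original (pre-broadcast) array x.
function unbroadcast_batch(x::AbstractVecOrMat, Δ::AbstractMatrix)
  size(Δ) == size(x) && return Δ        # nothing was broadcast
  reshape(sum(Δ, dims = 2), size(x))    # collapse the batch dimension
end

unbroadcast_batch(randn(5), randn(5, 4))  # 5-element vector of row sums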