diff --git a/src/cuda/cudnn.jl b/src/cuda/cudnn.jl
index 5fc01df5..972163f0 100644
--- a/src/cuda/cudnn.jl
+++ b/src/cuda/cudnn.jl
@@ -105,7 +105,7 @@
 getworkspace(bytes) = (workspace[] = CuVector{UInt8}(bytes))
 
 getworkspace(r::RNNDesc, seqlen, xdesc) =
-  getworkspace(rnnWorkspaceSize(rnn, seqlen, xdesc))
+  getworkspace(rnnWorkspaceSize(r, seqlen, xdesc))
 
 function rnnTrainingReserveSize(r::RNNDesc, seqlen, xdesc)
   size = Csize_t[0]
@@ -145,6 +145,7 @@ end
 xDesc(x) = [TensorDesc(eltype(x), (1, size(x, 1), size(x, 2)))]
 
 hDesc(h::Void) = C_NULL, C_NULL
+hDesc(x::Integer) = (@assert x == 0; hDesc(nothing))
 function hDesc(h::CuArray)
   TensorDesc(eltype(h), (size(h, 1), size(h, 2), 1)), h
 end
@@ -272,19 +273,72 @@ function desc(rnn)
   return d
 end
 
+import Flux.Tracker: data, isleaf, istracked, track, back_, @back
+
+struct RNNCall{R}
+  rnn::R
+end
+
+(c::RNNCall)(args...) = forward(desc(c.rnn), args..., train = true)
+
 istrain(m::CuRNNs, args...) = any(x -> x isa TrackedArray, (m.Wi, m.Wh, m.b, args...))
 
 function (m::CuRNN{T})(h::CuParam{T}, x::CuParam{T}) where T <: Union{Float32,Float64}
-  y, h = forward(desc(m), Flux.data(x), Flux.data(h), train = istrain(m, h, x))
-  return h, y
+  result = istrain(m, h, x) ?
+    track(RNNCall(m), x, h) :
+    forward(desc(m), x, h)
+  return result[2], result[1]
 end
 
 function (m::CuGRU{T})(h::CuParam{T}, x::CuParam{T}) where T <: Union{Float32,Float64}
-  y, h = forward(desc(m), Flux.data(x), Flux.data(h), train = istrain(m, h, x))
-  return h, y
+  result = istrain(m, h, x) ?
+    track(RNNCall(m), x, h) :
+    forward(desc(m), x, h)
+  return result[2], result[1]
 end
 
 function (m::CuLSTM{T})(h::NTuple{2,CuParam{T}}, x::CuParam{T}) where T <: Union{Float32,Float64}
-  y, h, c = forward(desc(m), Flux.data(x), Flux.data.(h)..., train = istrain(m, h, x))
-  return (h, c), y
+  result = istrain(m, h, x) ?
+    track(RNNCall(m), x, h[1], h[2]) :
+    forward(desc(m), x, h[1], h[2])
+  return (result[2], result[3]), result[1]
 end
+
+function accum_transpose!(dst::CuArray, src::CuArray)
+  function kernel(dst, src)
+    I = @cuindex dst
+    dst[I...] += src[reverse(I)...]
+    return
+  end
+  blk, thr = cudims(dst)
+  @cuda (blk, thr) kernel(dst, src)
+  return dst
+end
+
+function back_(m::RNNCall{<:Union{CuRNN,CuGRU}}, y_, Δ, x, h)
+  y, ho = y_
+  dy, dho = Δ
+  dx, dh = backwardData(descs[m.rnn], y, dy, dho, data(h))
+  @back(x, dx)
+  @back(h, dh)
+  (dWi, dWh), db = backwardWeights(descs[m.rnn], data(x), data(h), y)
+  # We don't have to make this assumption, it's just slightly more complex.
+  @assert all(isleaf.((m.rnn.Wi, m.rnn.Wh, m.rnn.b)))
+  istracked(m.rnn.Wi) && accum_transpose!(m.rnn.Wi.grad, dWi)
+  istracked(m.rnn.Wh) && accum_transpose!(m.rnn.Wh.grad, dWh)
+  istracked(m.rnn.b) && accum_transpose!(m.rnn.b.grad, db)
+end
+
+function back_(m::RNNCall{<:CuLSTM}, y_, Δ, x, h, c)
+  y, ho, co = y_
+  dy, dho, dco = Δ
+  dx, dh, dc = backwardData(descs[m.rnn], y, dy, dho, dco, data(h), data(c))
+  (dWi, dWh), db = backwardWeights(descs[m.rnn], data(x), data(h), y)
+  @back(x, dx)
+  @back(h, dh)
+  @back(c, dc)
+  @assert all(isleaf.((m.rnn.Wi, m.rnn.Wh, m.rnn.b)))
+  istracked(m.rnn.Wi) && accum_transpose!(m.rnn.Wi.grad, dWi)
+  istracked(m.rnn.Wh) && accum_transpose!(m.rnn.Wh.grad, dWh)
+  istracked(m.rnn.b) && accum_transpose!(m.rnn.b.grad, db)
+end
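
Note (not part of the patch): the mechanism above is Flux's old Tracker AD. As the diff itself uses it, `track(f, xs...)` runs `f` forward and records the call, the backward pass invokes `back_(f, y, Δ, xs...)` with the saved forward output `y` and the incoming gradient `Δ`, and `@back(x, dx)` propagates `dx` to the (possibly tracked) input `x`. A minimal sketch of that contract, assuming the Tracker API of this era:

# Illustrative sketch of the Tracker contract RNNCall plugs into
# (assumes the Flux.Tracker API of this era; not part of the patch).
using Flux
import Flux.Tracker: track, back_, @back

struct Double end
(::Double)(x) = 2 .* x                        # forward pass, recorded by track
back_(::Double, y, Δ, x) = @back(x, 2 .* Δ)   # pullback: d(2x)/dx = 2

x = Flux.param([1.0, 2.0])
y = track(Double(), x)       # tracked forward call, like track(RNNCall(m), ...)
Flux.back!(y, [1.0, 1.0])    # seed the output gradient
x.grad                       # == [2.0, 2.0]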
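
The `accum_transpose!` kernel accumulates the transpose of `src` into `dst`: for a matrix, `reverse(I)` swaps an index `(i, j)` to `(j, i)`, while for the 1-D bias gradient `reverse` is a no-op and it degenerates to plain accumulation. Presumably this is needed because CUDNN hands back the packed weight gradients transposed relative to Flux's `Wi`/`Wh` layout. A CPU-semantics equivalent, for illustration only:

# CPU-semantics sketch of the accum_transpose! kernel (illustration only;
# accum_transpose_cpu! is a hypothetical helper, not part of the patch).
function accum_transpose_cpu!(dst::AbstractMatrix, src::AbstractMatrix)
  for j = 1:size(dst, 2), i = 1:size(dst, 1)
    dst[i, j] += src[j, i]   # accumulate the transpose of src into dst
  end
  return dst
end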
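
End to end, the patch makes the CUDNN fast path differentiable rather than inference-only: in training mode the layer call is routed through `track(RNNCall(m), ...)`, so CUDNN's `backwardData`/`backwardWeights` run during the backward pass. A hypothetical usage sketch; the `mapleaves(cu, ...)` conversion, layer sizes, and field accesses are assumptions based on the Flux + CuArrays idioms of this era:

# Hypothetical usage sketch (assumes the Flux + CuArrays APIs of this era).
using Flux, CuArrays

m = Flux.mapleaves(cu, Flux.LSTM(10, 5))  # move an LSTM onto the GPU
x = cu(rand(Float32, 10, 3))              # a batch of 3 input columns
y = m(x)                                  # forward via the CuLSTM method above
Flux.back!(sum(y))                        # backward via back_(::RNNCall{<:CuLSTM}, ...)
m.cell.Wi.grad                            # gradient accumulated by accum_transpose!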