diff --git a/src/cuda/cudnn.jl b/src/cuda/cudnn.jl index c9728a48..5fc01df5 100644 --- a/src/cuda/cudnn.jl +++ b/src/cuda/cudnn.jl @@ -97,6 +97,16 @@ function rnnWorkspaceSize(r::RNNDesc, seqlen, xdesc) return Int(size[]) end +const workspace = [CuVector{UInt8}(1)] + +getworkspace(bytes) = + length(workspace[]) ≥ bytes ? + workspace[] : + (workspace[] = CuVector{UInt8}(bytes)) + +getworkspace(r::RNNDesc, seqlen, xdesc) = + getworkspace(rnnWorkspaceSize(r, seqlen, xdesc)) + function rnnTrainingReserveSize(r::RNNDesc, seqlen, xdesc) size = Csize_t[0] @check ccall((:cudnnGetRNNTrainingReserveSize,libcudnn), cudnnStatus_t, (Ptr{Void}, Ptr{Void}, Cint, Ptr{Ptr{Void}}, Ptr{Csize_t}), @@ -109,12 +119,14 @@ function getreserve(r::RNNDesc, seqlen, xdesc) sz ≤ length(r.reserve) ? r.reserve : (r.reserve = CuVector{UInt8}(sz)) end -function cudnnRNNForward(::Type{T}, rnn, seqlen, xd, x, hd, h, cd, c, wd, w, yd, y, hod, ho, cod, co, +function cudnnRNNForward(rnn::RNNDesc{T}, seqlen, xd, x, hd, h, cd, c, wd, w, yd, y, hod, ho, cod, co, workspace, reserve=nothing; train = (reserve ≠ nothing)) where T if !train @check ccall((:cudnnRNNForwardInference, libcudnn), cudnnStatus_t, (Ptr{Void}, Ptr{Void}, Cint, - Ptr{Ptr{Void}}, Ptr{T}, Ptr{Void}, Ptr{T}, Ptr{Void}, Ptr{T}, Ptr{Void}, Ptr{T}, Ptr{Ptr{Void}}, Ptr{T}, Ptr{Void}, Ptr{T}, Ptr{Void}, Ptr{T}, + Ptr{Ptr{Void}}, Ptr{T}, Ptr{Void}, Ptr{T}, Ptr{Void}, Ptr{T}, + Ptr{Void}, Ptr{T}, Ptr{Ptr{Void}}, Ptr{T}, Ptr{Void}, Ptr{T}, + Ptr{Void}, Ptr{T}, Ptr{Void}, Csize_t), libcudnn_handle[], rnn, seqlen, xd, x, hd, h, cd, c, wd, w, yd, y, hod, ho, cod, co, @@ -130,6 +142,8 @@ function cudnnRNNForward(::Type{T}, rnn, seqlen, xd, x, hd, h, cd, c, wd, w, yd, end end +xDesc(x) = [TensorDesc(eltype(x), (1, size(x, 1), size(x, 2)))] + hDesc(h::Void) = C_NULL, C_NULL function hDesc(h::CuArray) TensorDesc(eltype(h), (size(h, 1), size(h, 2), 1)), h @@ -140,26 +154,73 @@ function forward(rnn::RNNDesc{T}, x::CuArray{T}, h::CuArray{T}, c = nothing; 
tra @assert size(h, 1) == rnn.hidden @assert size(x, 2) == size(h, 2) seqLength = 1 - xdesc = [TensorDesc(T, (1, size(x, 1), size(x, 2)))] + xdesc = xDesc(x) y = x isa AbstractVector ? similar(x, rnn.hidden) : similar(x, rnn.hidden, size(x, 2)) - ydesc = [TensorDesc(T, (1, size(y, 1), size(y, 2)))] - workspace = CuVector{UInt8}(rnnWorkspaceSize(rnn, seqLength, xdesc)) # TODO: reuse this + ydesc = xDesc(y) + workspace = getworkspace(rnn, seqLength, xdesc) reserve = train ? getreserve(rnn, seqLength, xdesc) : rnn.reserve - cy = c == nothing ? c : similar(c) - cudnnRNNForward(T, rnn, seqLength, + co = c == nothing ? c : similar(c) + cudnnRNNForward(rnn, seqLength, xdesc, x, hDesc(h)..., hDesc(c)..., - TensorDesc(T, (1, 1, length(rnn.params))), rnn.params, + FilterDesc(T, (1, 1, length(rnn.params))), rnn.params, ydesc, y, C_NULL, C_NULL, # hout - hDesc(cy)..., + hDesc(co)..., workspace, reserve, train = train) - if c == nothing - return y, y - else - return y, y, cy - end + return c == nothing ? (y, y) : (y, y, co) +end + +function cudnnRNNBackwardData(rnn::RNNDesc{T}, seqlen, yd, y, dyd, dy, dhod, dho, dcod, dco, + wd, w, hd, h, cd, c, dxd, dx, dhd, dh, dcd, dc, ws, rs) where T + @check ccall((:cudnnRNNBackwardData,libcudnn),cudnnStatus_t, + (Ptr{Void}, Ptr{Void}, Cint, + Ptr{Ptr{Void}}, Ptr{T}, Ptr{Ptr{Void}}, Ptr{T}, Ptr{Void}, Ptr{T}, + Ptr{Void}, Ptr{T}, Ptr{Void}, Ptr{T}, Ptr{Void}, Ptr{T}, Ptr{Void}, + Ptr{T}, Ptr{Ptr{Void}}, Ptr{T}, Ptr{Void}, Ptr{T}, Ptr{Void}, Ptr{T}, + Ptr{Void}, Csize_t, Ptr{Void}, Csize_t), + libcudnn_handle[], rnn, seqlen, yd, y, dyd, dy, dhod, dho, dcod, dco, + wd, w, hd, h, cd, c, dxd, dx, dhd, dh, dcd, dc, ws, length(ws), rs, length(rs)) +end + +function backwardData(rnn::RNNDesc{T}, y, dy, dho, dco, h, c) where T + yd = xDesc(y) + dx = y isa AbstractVector ? similar(dy, rnn.input) : similar(dy, rnn.input, size(dy, 2)) + dh = similar(h) + dc = c == nothing ? 
nothing : similar(c) + cudnnRNNBackwardData(rnn, 1, + yd, y, yd, dy, hDesc(dho)..., hDesc(dco)..., + FilterDesc(T, (1, 1, length(rnn.params))), rnn.params, + hDesc(h)..., hDesc(c)..., xDesc(dx), dx, hDesc(dh)..., hDesc(dc)..., + workspace[], rnn.reserve) + return c == nothing ? (dx, dh) : (dx, dh, dc) +end + +backwardData(rnn, y, dy, dho, hx) = + backwardData(rnn, y, dy, dho, nothing, hx, nothing) + +function cudnnRNNBackwardWeights(rnn::RNNDesc{T}, seqlen, xd, x, hd, h, yd, y, dwd, dw, + workspace, reserve) where T + @check ccall((:cudnnRNNBackwardWeights,libcudnn), cudnnStatus_t, + (Ptr{Void}, Ptr{Void}, Cint, # handle, rnnDesc, seqLength + Ptr{Ptr{Void}}, Ptr{T}, #x + Ptr{Void}, Ptr{T}, #hx + Ptr{Ptr{Void}}, Ptr{T}, #y + Ptr{Void}, Csize_t, #ws + Ptr{Void}, Ptr{T}, #dw + Ptr{Void}, Csize_t), #rs + libcudnn_handle[], rnn, seqlen, xd, x, hd, h, yd, y, + workspace, length(workspace), dwd, dw, reserve, length(reserve)) +end + +function backwardWeights(rnn::RNNDesc{T}, x, h, y) where T + dw = zeros(rnn.params) + cudnnRNNBackwardWeights(rnn, 1, + xDesc(x), x, hDesc(h)..., xDesc(y), y, + FilterDesc(T, (1, 1, length(dw))), dw, + workspace[], rnn.reserve) + return params(dw, rnn.input, rnn.hidden) end # Interface @@ -194,7 +255,7 @@ function copyparams!(m::CuRNNs, d::RNNDesc) return end -function RNNDesc(m::CuRNNs{T}) where {T} +function RNNDesc(m::CuRNNs{T}) where T h, i = length(m.h), size(m.Wi, 2) mode = m isa CuRNN ? (m.σ == tanh ? RNN_TANH : RNN_RELU) :