diff --git a/src/cuda/cuda.jl b/src/cuda/cuda.jl index 028a0f8b..00f0d0f2 100644 --- a/src/cuda/cuda.jl +++ b/src/cuda/cuda.jl @@ -3,6 +3,7 @@ module CUDA using ..CuArrays if CuArrays.libcudnn !== nothing # TODO: use CuArrays.has_cudnn() + using CuArrays: CUDNN include("curnn.jl") include("cudnn.jl") else diff --git a/src/cuda/cudnn.jl b/src/cuda/cudnn.jl index 448ea140..aa16f926 100644 --- a/src/cuda/cudnn.jl +++ b/src/cuda/cudnn.jl @@ -1,7 +1,6 @@ -using CuArrays: libcudnn -using CuArrays.CUDNN: @check, handle, cudnnStatus_t, cudnnTensorDescriptor_t, - cudnnBatchNormMode_t, cudnnHandle_t, cudnnDataType, TensorDesc, FilterDesc -import CuArrays.CUDAdrv: CuPtr, CU_NULL +using CuArrays.CUDNN: handle, TensorDesc, FilterDesc + +import CuArrays.CUDAdrv: CU_NULL using LinearAlgebra @@ -15,22 +14,17 @@ Base.unsafe_convert(::Type{Ptr{Nothing}}, dd::DropoutDesc) = dd.ptr function DropoutDesc(ρ::Real; seed::Integer=0) d = [C_NULL] s = Csize_t[0] - @check ccall((:cudnnCreateDropoutDescriptor,libcudnn), cudnnStatus_t, (Ptr{Ptr{Nothing}},), d) - @check ccall((:cudnnDropoutGetStatesSize,libcudnn),cudnnStatus_t,(Ptr{Nothing},Ptr{Csize_t}),handle(),s) + CUDNN.cudnnCreateDropoutDescriptor(d) + CUDNN.cudnnDropoutGetStatesSize(handle(), s) states = CuArray{UInt8}(undef, s[]) # TODO: can we drop this when ρ=0? desc = DropoutDesc(d[], states) - @check ccall((:cudnnSetDropoutDescriptor,libcudnn),cudnnStatus_t,(Ptr{Nothing},Ptr{Nothing},Cfloat,CuPtr{Nothing},Csize_t,Culonglong), - desc,handle(),ρ,states,length(states),seed) + CUDNN.cudnnSetDropoutDescriptor(desc, handle(), ρ, states, length(states), seed) finalizer(desc) do x - @check ccall((:cudnnDestroyDropoutDescriptor,libcudnn),cudnnStatus_t,(Ptr{Nothing},),x) + CUDNN.cudnnDestroyDropoutDescriptor(x) end return desc end -const BATCHNORM_SPATIAL = 1 -const BATCHNORM_ACTIVATION = 0 -const BATCHNORM_MIN_EPS = 1e-5 - @inline _wsize(y) = (map(_ -> 1, size(y)[1:end-2])..., size(y)[end-1], 1) @inline _reddims(y) = (collect(1:ndims(y)-2)..., ndims(y)) @@ -67,9 +61,9 @@ function cudnnBNForward!(y::CuArray{T}, g::CuArray{T}, b::CuArray{T}, x::CuArray alpha = T(1), beta = T(0), eps = T(1e-5), training = true) where T<:Union{Float32, Float64} dims = _wsize(x) - if eps < BATCHNORM_MIN_EPS - # warn("eps ",eps," is too small for CuDNN so eps has been assigned the value ", BATCHNORM_MIN_EPS) - eps = BATCHNORM_MIN_EPS + if eps < CUDNN.CUDNN_BN_MIN_EPSILON + # warn("eps ",eps," is too small for CuDNN so eps has been assigned the value ", CUDNN.CUDNN_BN_MIN_EPSILON) + eps = CUDNN.CUDNN_BN_MIN_EPSILON end xd = TensorDesc(x) yd = TensorDesc(y) @@ -85,42 +79,14 @@ function cudnnBNForward!(y::CuArray{T}, g::CuArray{T}, b::CuArray{T}, x::CuArray ivar = CU_NULL end - @check ccall((:cudnnBatchNormalizationForwardTraining, libcudnn), cudnnStatus_t, - (cudnnHandle_t,cudnnBatchNormMode_t, - Ptr{T}, Ptr{T}, - Ptr{Nothing}, CuPtr{T}, - Ptr{Nothing}, CuPtr{T}, - Ptr{Nothing}, CuPtr{T}, CuPtr{T}, - Cdouble, CuPtr{T}, CuPtr{T}, - Cdouble, CuPtr{T}, CuPtr{T}), - handle(), BATCHNORM_SPATIAL, - Ref(T(alpha)), Ref(T(beta)), - xd, x, - yd, y, - gd, g, b, - momentum, running_mean, running_var, - eps, mean, ivar) + CUDNN.cudnnBatchNormalizationForwardTraining(handle(), CUDNN.CUDNN_BATCHNORM_SPATIAL, Ref(T(alpha)), Ref(T(beta)), xd, x, yd, y, gd, g, b, momentum, running_mean, running_var, eps, mean, ivar) if cache !== nothing cache.mean = mean cache.ivar = ivar end else - @check ccall((:cudnnBatchNormalizationForwardInference, libcudnn), cudnnStatus_t, - (Ptr{cudnnHandle_t},cudnnBatchNormMode_t, - Ptr{T}, Ptr{T}, - Ptr{Nothing}, CuPtr{T}, - Ptr{Nothing}, CuPtr{T}, - Ptr{Nothing}, CuPtr{T}, CuPtr{T}, - CuPtr{T}, CuPtr{T}, - Cdouble), - handle(), BATCHNORM_SPATIAL, - Ref(T(alpha)), Ref(T(beta)), - xd, x, - yd, y, - gd, g, b, - running_mean, running_var, - eps) + CUDNN.cudnnBatchNormalizationForwardInference(handle(), CUDNN.CUDNN_BATCHNORM_SPATIAL, Ref(T(alpha)), Ref(T(beta)), xd, x, yd, y, gd, g, b, running_mean, running_var, eps) end end @@ -164,27 +130,11 @@ function cudnnBNBackward!(dg::CuArray{T}, g::CuArray{T}, db::CuArray{T}, mean, ivar = CU_NULL, CU_NULL end - if eps < BATCHNORM_MIN_EPS - eps = BATCHNORM_MIN_EPS + if eps < CUDNN.CUDNN_BN_MIN_EPSILON + eps = CUDNN.CUDNN_BN_MIN_EPSILON end - @check ccall((:cudnnBatchNormalizationBackward, libcudnn), cudnnStatus_t, - (cudnnHandle_t,cudnnBatchNormMode_t, - Ptr{T}, Ptr{T}, - Ptr{T}, Ptr{T}, - Ptr{Nothing}, CuPtr{T}, - Ptr{Nothing}, CuPtr{T}, - Ptr{Nothing}, CuPtr{T}, - Ptr{Nothing}, CuPtr{T}, CuPtr{T}, CuPtr{T}, - Cdouble, CuPtr{T}, CuPtr{T}), - handle(), BATCHNORM_SPATIAL, - Ref(T(alpha)), Ref(T(beta)), - Ref(T(dalpha)), Ref(T(dbeta)), - xd, x, - dyd, dy, - dxd, dx, - gd, g, dg, db, - eps, mean, ivar) + CUDNN.cudnnBatchNormalizationBackward(handle(), CUDNN.CUDNN_BATCHNORM_SPATIAL, Ref(T(alpha)), Ref(T(beta)), Ref(T(dalpha)), Ref(T(dbeta)), xd, x, dyd, dy, dxd, dx, gd, g, dg, db, eps, mean, ivar) else ivar = 1 ./ sqrt.(reshape(running_var, _wsize(x)) .+ eps) dx .= dy .* reshape(g, _wsize(x)) .* ivar diff --git a/src/cuda/curnn.jl b/src/cuda/curnn.jl index ca8b5140..c37d031c 100644 --- a/src/cuda/curnn.jl +++ b/src/cuda/curnn.jl @@ -1,8 +1,6 @@ -using CuArrays: libcudnn -using CuArrays.CUDNN: @check, cudnnStatus_t, cudnnTensorDescriptor_t, - cudnnBatchNormMode_t, cudnnHandle_t, cudnnDataType, TensorDesc, FilterDesc +using CuArrays.CUDNN: handle, cudnnDataType, TensorDesc, FilterDesc -import CuArrays.CUDAdrv: CuPtr, CU_NULL +import CuArrays.CUDAdrv: CU_NULL using LinearAlgebra @@ -48,8 +46,7 @@ Base.unsafe_convert(::Type{Ptr{Nothing}}, d::RNNDesc) = d.ptr function rnnParamSize(T, r, input) size = Csize_t[0] - @check ccall((:cudnnGetRNNParamsSize, libcudnn), cudnnStatus_t, (Ptr{Nothing},Ptr{Nothing},Ptr{Nothing},Ptr{Csize_t},Cint), - handle(), r, TensorDesc(T, (1,input,1)), size, cudnnDataType(T)) + CUDNN.cudnnGetRNNParamsSize(handle(), r, TensorDesc(T, (1,input,1)), size, cudnnDataType(T)) return Int(size[])÷sizeof(T) end @@ -58,28 +55,26 @@ ngates(r::RNNDesc) = ngates(r.mode) function RNNDesc{T}(mode::Int, input::Int, hidden::Int; layers = 1) where T d = [C_NULL] - @check ccall((:cudnnCreateRNNDescriptor,libcudnn),cudnnStatus_t,(Ptr{Ptr{Nothing}},),d) + CUDNN.cudnnCreateRNNDescriptor(d) dropoutDesc = DropoutDesc(0) inputMode = LINEAR_INPUT direction = UNIDIRECTIONAL algo = RNN_ALGO_STANDARD - @check ccall((:cudnnSetRNNDescriptor_v6,libcudnn), cudnnStatus_t, (Ptr{Nothing},Ptr{Nothing},Cint,Cint,Ptr{Nothing},Cint,Cint,Cint,Cint,Cint), - handle(),d[],hidden,layers,dropoutDesc,inputMode,direction,mode,algo,cudnnDataType(T)) + CUDNN.cudnnSetRNNDescriptor_v6(handle(),d[],hidden,layers,dropoutDesc,CUDNN.cudnnRNNInputMode_t(inputMode),CUDNN.cudnnDirectionMode_t(direction),CUDNN.cudnnRNNMode_t(mode),CUDNN.cudnnRNNAlgo_t(algo),cudnnDataType(T)) - w = CuArrays.zeros(T, rnnParamSize(T, d[], input)) + w =CuArrays.zeros(T, rnnParamSize(T, d[], input)) # TODO: avoid reserve allocation here rd = RNNDesc{T}(mode, input, hidden, w, params(w, input, hidden, ngates(mode))..., d[]) finalizer(rd) do x - @check ccall((:cudnnDestroyRNNDescriptor,libcudnn),cudnnStatus_t,(Ptr{Nothing},),x) + CUDNN.cudnnDestroyRNNDescriptor(x) end return rd end function rnnWorkspaceSize(r::RNNDesc, seqlen, xdesc) size = Csize_t[0] - @check ccall((:cudnnGetRNNWorkspaceSize, libcudnn), cudnnStatus_t, (Ptr{Nothing},Ptr{Nothing},Cint,Ptr{Ptr{Nothing}},Ptr{Csize_t}), - handle(), r, seqlen, xdesc, size) + CUDNN.cudnnGetRNNWorkspaceSize(handle(), r, seqlen, xdesc, size) return Int(size[]) end @@ -95,31 +90,18 @@ getworkspace(r::RNNDesc, seqlen, xdesc) = function rnnTrainingReserveSize(r::RNNDesc, seqlen, xdesc) size = Csize_t[0] - @check ccall((:cudnnGetRNNTrainingReserveSize,libcudnn), cudnnStatus_t, (Ptr{Nothing}, Ptr{Nothing}, Cint, Ptr{Ptr{Nothing}}, Ptr{Csize_t}), - handle(), r, seqlen, xdesc, size) + CUDNN.cudnnGetRNNTrainingReserveSize(handle(), r, seqlen, xdesc, size) return Int(size[]) end function cudnnRNNForward(rnn::RNNDesc{T}, seqlen, xd, x, hd, h, cd, c, wd, w, yd, y, hod, ho, cod, co, workspace, reserve=nothing) where T if reserve == nothing - @check ccall((:cudnnRNNForwardInference, libcudnn), cudnnStatus_t, - (Ptr{Nothing}, Ptr{Nothing}, Cint, - Ptr{Ptr{Nothing}}, CuPtr{T}, Ptr{Nothing}, CuPtr{T}, Ptr{Nothing}, CuPtr{T}, - Ptr{Nothing}, CuPtr{T}, Ptr{Ptr{Nothing}}, CuPtr{T}, Ptr{Nothing}, CuPtr{T}, - Ptr{Nothing}, CuPtr{T}, - CuPtr{Nothing}, Csize_t), - handle(), rnn, seqlen, - xd, x, hd, h, cd, c, wd, w, yd, y, hod, ho, cod, co, - workspace, length(workspace)) + CUDNN.cudnnRNNForwardInference(handle(), rnn, seqlen, xd, x, hd, h, cd, c, wd, w, yd, y, + hod, ho, cod, co, workspace, length(workspace)) else - @check ccall((:cudnnRNNForwardTraining, libcudnn), cudnnStatus_t, - (Ptr{Nothing}, Ptr{Nothing}, Cint, - Ptr{Ptr{Nothing}}, CuPtr{T}, Ptr{Nothing}, CuPtr{T}, Ptr{Nothing}, CuPtr{T}, Ptr{Nothing}, CuPtr{T}, Ptr{Ptr{Nothing}}, CuPtr{T}, Ptr{Nothing}, CuPtr{T}, Ptr{Nothing}, CuPtr{T}, - CuPtr{Nothing}, Csize_t, CuPtr{Nothing}, Csize_t), - handle(), rnn, seqlen, - xd, x, hd, h, cd, c, wd, w, yd, y, hod, ho, cod, co, - workspace, length(workspace), reserve, length(reserve)) + CUDNN.cudnnRNNForwardTraining(handle(), rnn, seqlen, xd, x, hd, h, cd, c, wd, w, yd, y, + hod, ho, cod, co, workspace, length(workspace), reserve, length(reserve)) end end @@ -134,8 +116,8 @@ end # TODO: can we just manipulate strides here? # TODO: should use repmat, but this isn't implemented. hBatch(x::AbstractVector, h::CuVector) = h -hBatch(x::AbstractMatrix, h::CuVector) = h .* CuArrays.ones(1, size(x, 2)) -hBatch(x::AbstractMatrix, h::CuMatrix) = h .* CuArrays.ones(1, size(h,2) == 1 ? size(x,2) : 1) +hBatch(x::AbstractMatrix, h::CuVector) = h .*CuArrays.ones(1, size(x, 2)) +hBatch(x::AbstractMatrix, h::CuMatrix) = h .*CuArrays.ones(1, size(h,2) == 1 ? size(x,2) : 1) function forward(rnn::RNNDesc{T}, x::CuArray{T}, h_::CuArray{T}, c_ = nothing, train = Val{false}) where T h = hBatch(x, h_) @@ -169,18 +151,6 @@ end forwardTrain(rnn::RNNDesc{T}, x::CuArray{T}, h::CuArray{T}, c = nothing) where T = forward(rnn, x, h, c, Val{true}) -function cudnnRNNBackwardData(rnn::RNNDesc{T}, seqlen, yd, y, dyd, dy, dhod, dho, dcod, dco, - wd, w, hd, h, cd, c, dxd, dx, dhd, dh, dcd, dc, ws, rs) where T - @check ccall((:cudnnRNNBackwardData,libcudnn),cudnnStatus_t, - (Ptr{Nothing}, Ptr{Nothing}, Cint, - Ptr{Ptr{Nothing}}, CuPtr{T}, Ptr{Ptr{Nothing}}, CuPtr{T}, Ptr{Nothing}, CuPtr{T}, - Ptr{Nothing}, CuPtr{T}, Ptr{Nothing}, CuPtr{T}, Ptr{Nothing}, CuPtr{T}, Ptr{Nothing}, - CuPtr{T}, Ptr{Ptr{Nothing}}, CuPtr{T}, Ptr{Nothing}, CuPtr{T}, Ptr{Nothing}, CuPtr{T}, - CuPtr{Nothing}, Csize_t, CuPtr{Nothing}, Csize_t), - handle(), rnn, seqlen, yd, y, dyd, dy, dhod, dho, dcod, dco, - wd, w, hd, h, cd, c, dxd, dx, dhd, dh, dcd, dc, ws, length(ws), rs, length(rs)) -end - function backwardData(rnn::RNNDesc{T}, y, dy_, dho, dco, h, c, reserve) where T # Same as above, any more efficient way? dy = dy_ isa Integer ? zero(y) : dy_ @@ -188,37 +158,24 @@ function backwardData(rnn::RNNDesc{T}, y, dy_, dho, dco, h, c, reserve) where T dx = y isa AbstractVector ? similar(dy, rnn.input) : similar(dy, rnn.input, size(dy, 2)) dh = similar(h) dc = c == nothing ? nothing : similar(c) - cudnnRNNBackwardData(rnn, 1, + CUDNN.cudnnRNNBackwardData(handle(), rnn, 1, yd, y, yd, dy, hDesc(dho)..., hDesc(dco)..., FilterDesc(T, (1, 1, length(rnn.params))), rnn.params, hDesc(h)..., hDesc(c)..., xDesc(dx), dx, hDesc(dh)..., hDesc(dc)..., - workspace[], reserve) + workspace[], length(workspace[]), reserve, length(reserve)) return c == nothing ? (dx, dh) : (dx, dh, dc) end backwardData(rnn, y, dy, dho, hx, reserve) = backwardData(rnn, y, dy, dho, nothing, hx, nothing, reserve) -function cudnnRNNBackwardWeights(rnn::RNNDesc{T}, seqlen, xd, x, hd, h, yd, y, dwd, dw, - workspace, reserve) where T - @check ccall((:cudnnRNNBackwardWeights,libcudnn), cudnnStatus_t, - (Ptr{Nothing}, Ptr{Nothing}, Cint, # handle, rnnDesc, seqLength - Ptr{Ptr{Nothing}}, CuPtr{T}, #x - Ptr{Nothing}, CuPtr{T}, #hx - Ptr{Ptr{Nothing}}, CuPtr{T}, #y - CuPtr{Nothing}, Csize_t, #ws - Ptr{Nothing}, CuPtr{T}, #dw - CuPtr{Nothing}, Csize_t), #rs - handle(), rnn, seqlen, xd, x, hd, h, yd, y, - workspace, length(workspace), dwd, dw, reserve, length(reserve)) -end - function backwardWeights(rnn::RNNDesc{T}, x, h, y, reserve) where T dw = zero(rnn.params) - cudnnRNNBackwardWeights(rnn, 1, + CUDNN.cudnnRNNBackwardWeights(handle(), rnn, 1, xDesc(x), x, hDesc(h)..., xDesc(y), y, + workspace[], length(workspace[]), FilterDesc(T, (1, 1, length(dw))), dw, - workspace[], reserve) + reserve, length(reserve)) return params(dw, rnn.input, rnn.hidden, ngates(rnn)) end