From 04fce70019ee59a9ae8050ec8d683670f12e5942 Mon Sep 17 00:00:00 2001 From: Tim Besard Date: Thu, 29 Aug 2019 16:34:35 +0200 Subject: [PATCH 1/8] Move low-level CUDNN wrappers to CuArrays. --- src/cuda/cuda.jl | 1 + src/cuda/cudnn.jl | 80 +++++++++------------------------------------ src/cuda/curnn.jl | 83 ++++++++++++----------------------------------- 3 files changed, 36 insertions(+), 128 deletions(-) diff --git a/src/cuda/cuda.jl b/src/cuda/cuda.jl index 028a0f8b..00f0d0f2 100644 --- a/src/cuda/cuda.jl +++ b/src/cuda/cuda.jl @@ -3,6 +3,7 @@ module CUDA using ..CuArrays if CuArrays.libcudnn !== nothing # TODO: use CuArrays.has_cudnn() + using CuArrays: CUDNN include("curnn.jl") include("cudnn.jl") else diff --git a/src/cuda/cudnn.jl b/src/cuda/cudnn.jl index 448ea140..aa16f926 100644 --- a/src/cuda/cudnn.jl +++ b/src/cuda/cudnn.jl @@ -1,7 +1,6 @@ -using CuArrays: libcudnn -using CuArrays.CUDNN: @check, handle, cudnnStatus_t, cudnnTensorDescriptor_t, - cudnnBatchNormMode_t, cudnnHandle_t, cudnnDataType, TensorDesc, FilterDesc -import CuArrays.CUDAdrv: CuPtr, CU_NULL +using CuArrays.CUDNN: handle, TensorDesc, FilterDesc + +import CuArrays.CUDAdrv: CU_NULL using LinearAlgebra @@ -15,22 +14,17 @@ Base.unsafe_convert(::Type{Ptr{Nothing}}, dd::DropoutDesc) = dd.ptr function DropoutDesc(ρ::Real; seed::Integer=0) d = [C_NULL] s = Csize_t[0] - @check ccall((:cudnnCreateDropoutDescriptor,libcudnn), cudnnStatus_t, (Ptr{Ptr{Nothing}},), d) - @check ccall((:cudnnDropoutGetStatesSize,libcudnn),cudnnStatus_t,(Ptr{Nothing},Ptr{Csize_t}),handle(),s) + CUDNN.cudnnCreateDropoutDescriptor(d) + CUDNN.cudnnDropoutGetStatesSize(handle(), s) states = CuArray{UInt8}(undef, s[]) # TODO: can we drop this when ρ=0? desc = DropoutDesc(d[], states) - @check ccall((:cudnnSetDropoutDescriptor,libcudnn),cudnnStatus_t,(Ptr{Nothing},Ptr{Nothing},Cfloat,CuPtr{Nothing},Csize_t,Culonglong), - desc,handle(),ρ,states,length(states),seed) + CUDNN.cudnnSetDropoutDescriptor(desc, handle(), ρ, states, length(states), seed) finalizer(desc) do x - @check ccall((:cudnnDestroyDropoutDescriptor,libcudnn),cudnnStatus_t,(Ptr{Nothing},),x) + CUDNN.cudnnDestroyDropoutDescriptor(x) end return desc end -const BATCHNORM_SPATIAL = 1 -const BATCHNORM_ACTIVATION = 0 -const BATCHNORM_MIN_EPS = 1e-5 - @inline _wsize(y) = (map(_ -> 1, size(y)[1:end-2])..., size(y)[end-1], 1) @inline _reddims(y) = (collect(1:ndims(y)-2)..., ndims(y)) @@ -67,9 +61,9 @@ function cudnnBNForward!(y::CuArray{T}, g::CuArray{T}, b::CuArray{T}, x::CuArray alpha = T(1), beta = T(0), eps = T(1e-5), training = true) where T<:Union{Float32, Float64} dims = _wsize(x) - if eps < BATCHNORM_MIN_EPS - # warn("eps ",eps," is too small for CuDNN so eps has been assigned the value ", BATCHNORM_MIN_EPS) - eps = BATCHNORM_MIN_EPS + if eps < CUDNN.CUDNN_BN_MIN_EPSILON + # warn("eps ",eps," is too small for CuDNN so eps has been assigned the value ", CUDNN.CUDNN_BN_MIN_EPSILON) + eps = CUDNN.CUDNN_BN_MIN_EPSILON end xd = TensorDesc(x) yd = TensorDesc(y) @@ -85,42 +79,14 @@ function cudnnBNForward!(y::CuArray{T}, g::CuArray{T}, b::CuArray{T}, x::CuArray ivar = CU_NULL end - @check ccall((:cudnnBatchNormalizationForwardTraining, libcudnn), cudnnStatus_t, - (cudnnHandle_t,cudnnBatchNormMode_t, - Ptr{T}, Ptr{T}, - Ptr{Nothing}, CuPtr{T}, - Ptr{Nothing}, CuPtr{T}, - Ptr{Nothing}, CuPtr{T}, CuPtr{T}, - Cdouble, CuPtr{T}, CuPtr{T}, - Cdouble, CuPtr{T}, CuPtr{T}), - handle(), BATCHNORM_SPATIAL, - Ref(T(alpha)), Ref(T(beta)), - xd, x, - yd, y, - gd, g, b, - momentum, running_mean, running_var, - eps, mean, ivar) + CUDNN.cudnnBatchNormalizationForwardTraining(handle(), CUDNN.CUDNN_BATCHNORM_SPATIAL, Ref(T(alpha)), Ref(T(beta)), xd, x, yd, y, gd, g, b, momentum, running_mean, running_var, eps, mean, ivar) if cache !== nothing cache.mean = mean cache.ivar = ivar end else - @check ccall((:cudnnBatchNormalizationForwardInference, libcudnn), cudnnStatus_t, - (Ptr{cudnnHandle_t},cudnnBatchNormMode_t, - Ptr{T}, Ptr{T}, - Ptr{Nothing}, CuPtr{T}, - Ptr{Nothing}, CuPtr{T}, - Ptr{Nothing}, CuPtr{T}, CuPtr{T}, - CuPtr{T}, CuPtr{T}, - Cdouble), - handle(), BATCHNORM_SPATIAL, - Ref(T(alpha)), Ref(T(beta)), - xd, x, - yd, y, - gd, g, b, - running_mean, running_var, - eps) + CUDNN.cudnnBatchNormalizationForwardInference(handle(), CUDNN.CUDNN_BATCHNORM_SPATIAL, Ref(T(alpha)), Ref(T(beta)), xd, x, yd, y, gd, g, b, running_mean, running_var, eps) end end @@ -164,27 +130,11 @@ function cudnnBNBackward!(dg::CuArray{T}, g::CuArray{T}, db::CuArray{T}, mean, ivar = CU_NULL, CU_NULL end - if eps < BATCHNORM_MIN_EPS - eps = BATCHNORM_MIN_EPS + if eps < CUDNN.CUDNN_BN_MIN_EPSILON + eps = CUDNN.CUDNN_BN_MIN_EPSILON end - @check ccall((:cudnnBatchNormalizationBackward, libcudnn), cudnnStatus_t, - (cudnnHandle_t,cudnnBatchNormMode_t, - Ptr{T}, Ptr{T}, - Ptr{T}, Ptr{T}, - Ptr{Nothing}, CuPtr{T}, - Ptr{Nothing}, CuPtr{T}, - Ptr{Nothing}, CuPtr{T}, - Ptr{Nothing}, CuPtr{T}, CuPtr{T}, CuPtr{T}, - Cdouble, CuPtr{T}, CuPtr{T}), - handle(), BATCHNORM_SPATIAL, - Ref(T(alpha)), Ref(T(beta)), - Ref(T(dalpha)), Ref(T(dbeta)), - xd, x, - dyd, dy, - dxd, dx, - gd, g, dg, db, - eps, mean, ivar) + CUDNN.cudnnBatchNormalizationBackward(handle(), CUDNN.CUDNN_BATCHNORM_SPATIAL, Ref(T(alpha)), Ref(T(beta)), Ref(T(dalpha)), Ref(T(dbeta)), xd, x, dyd, dy, dxd, dx, gd, g, dg, db, eps, mean, ivar) else ivar = 1 ./ sqrt.(reshape(running_var, _wsize(x)) .+ eps) dx .= dy .* reshape(g, _wsize(x)) .* ivar diff --git a/src/cuda/curnn.jl b/src/cuda/curnn.jl index ca8b5140..c37d031c 100644 --- a/src/cuda/curnn.jl +++ b/src/cuda/curnn.jl @@ -1,8 +1,6 @@ -using CuArrays: libcudnn -using CuArrays.CUDNN: @check, cudnnStatus_t, cudnnTensorDescriptor_t, - cudnnBatchNormMode_t, cudnnHandle_t, cudnnDataType, TensorDesc, FilterDesc +using CuArrays.CUDNN: handle, cudnnDataType, TensorDesc, FilterDesc -import CuArrays.CUDAdrv: CuPtr, CU_NULL +import CuArrays.CUDAdrv: CU_NULL using LinearAlgebra @@ -48,8 +46,7 @@ Base.unsafe_convert(::Type{Ptr{Nothing}}, d::RNNDesc) = d.ptr function rnnParamSize(T, r, input) size = Csize_t[0] - @check ccall((:cudnnGetRNNParamsSize, libcudnn), cudnnStatus_t, (Ptr{Nothing},Ptr{Nothing},Ptr{Nothing},Ptr{Csize_t},Cint), - handle(), r, TensorDesc(T, (1,input,1)), size, cudnnDataType(T)) + CUDNN.cudnnGetRNNParamsSize(handle(), r, TensorDesc(T, (1,input,1)), size, cudnnDataType(T)) return Int(size[])÷sizeof(T) end @@ -58,28 +55,26 @@ ngates(r::RNNDesc) = ngates(r.mode) function RNNDesc{T}(mode::Int, input::Int, hidden::Int; layers = 1) where T d = [C_NULL] - @check ccall((:cudnnCreateRNNDescriptor,libcudnn),cudnnStatus_t,(Ptr{Ptr{Nothing}},),d) + CUDNN.cudnnCreateRNNDescriptor(d) dropoutDesc = DropoutDesc(0) inputMode = LINEAR_INPUT direction = UNIDIRECTIONAL algo = RNN_ALGO_STANDARD - @check ccall((:cudnnSetRNNDescriptor_v6,libcudnn), cudnnStatus_t, (Ptr{Nothing},Ptr{Nothing},Cint,Cint,Ptr{Nothing},Cint,Cint,Cint,Cint,Cint), - handle(),d[],hidden,layers,dropoutDesc,inputMode,direction,mode,algo,cudnnDataType(T)) + CUDNN.cudnnSetRNNDescriptor_v6(handle(),d[],hidden,layers,dropoutDesc,CUDNN.cudnnRNNInputMode_t(inputMode),CUDNN.cudnnDirectionMode_t(direction),CUDNN.cudnnRNNMode_t(mode),CUDNN.cudnnRNNAlgo_t(algo),cudnnDataType(T)) - w = CuArrays.zeros(T, rnnParamSize(T, d[], input)) + w =CuArrays.zeros(T, rnnParamSize(T, d[], input)) # TODO: avoid reserve allocation here rd = RNNDesc{T}(mode, input, hidden, w, params(w, input, hidden, ngates(mode))..., d[]) finalizer(rd) do x - @check ccall((:cudnnDestroyRNNDescriptor,libcudnn),cudnnStatus_t,(Ptr{Nothing},),x) + CUDNN.cudnnDestroyRNNDescriptor(x) end return rd end function rnnWorkspaceSize(r::RNNDesc, seqlen, xdesc) size = Csize_t[0] - @check ccall((:cudnnGetRNNWorkspaceSize, libcudnn), cudnnStatus_t, (Ptr{Nothing},Ptr{Nothing},Cint,Ptr{Ptr{Nothing}},Ptr{Csize_t}), - handle(), r, seqlen, xdesc, size) + CUDNN.cudnnGetRNNWorkspaceSize(handle(), r, seqlen, xdesc, size) return Int(size[]) end @@ -95,31 +90,18 @@ getworkspace(r::RNNDesc, seqlen, xdesc) = function rnnTrainingReserveSize(r::RNNDesc, seqlen, xdesc) size = Csize_t[0] - @check ccall((:cudnnGetRNNTrainingReserveSize,libcudnn), cudnnStatus_t, (Ptr{Nothing}, Ptr{Nothing}, Cint, Ptr{Ptr{Nothing}}, Ptr{Csize_t}), - handle(), r, seqlen, xdesc, size) + CUDNN.cudnnGetRNNTrainingReserveSize(handle(), r, seqlen, xdesc, size) return Int(size[]) end function cudnnRNNForward(rnn::RNNDesc{T}, seqlen, xd, x, hd, h, cd, c, wd, w, yd, y, hod, ho, cod, co, workspace, reserve=nothing) where T if reserve == nothing - @check ccall((:cudnnRNNForwardInference, libcudnn), cudnnStatus_t, - (Ptr{Nothing}, Ptr{Nothing}, Cint, - Ptr{Ptr{Nothing}}, CuPtr{T}, Ptr{Nothing}, CuPtr{T}, Ptr{Nothing}, CuPtr{T}, - Ptr{Nothing}, CuPtr{T}, Ptr{Ptr{Nothing}}, CuPtr{T}, Ptr{Nothing}, CuPtr{T}, - Ptr{Nothing}, CuPtr{T}, - CuPtr{Nothing}, Csize_t), - handle(), rnn, seqlen, - xd, x, hd, h, cd, c, wd, w, yd, y, hod, ho, cod, co, - workspace, length(workspace)) + CUDNN.cudnnRNNForwardInference(handle(), rnn, seqlen, xd, x, hd, h, cd, c, wd, w, yd, y, + hod, ho, cod, co, workspace, length(workspace)) else - @check ccall((:cudnnRNNForwardTraining, libcudnn), cudnnStatus_t, - (Ptr{Nothing}, Ptr{Nothing}, Cint, - Ptr{Ptr{Nothing}}, CuPtr{T}, Ptr{Nothing}, CuPtr{T}, Ptr{Nothing}, CuPtr{T}, Ptr{Nothing}, CuPtr{T}, Ptr{Ptr{Nothing}}, CuPtr{T}, Ptr{Nothing}, CuPtr{T}, Ptr{Nothing}, CuPtr{T}, - CuPtr{Nothing}, Csize_t, CuPtr{Nothing}, Csize_t), - handle(), rnn, seqlen, - xd, x, hd, h, cd, c, wd, w, yd, y, hod, ho, cod, co, - workspace, length(workspace), reserve, length(reserve)) + CUDNN.cudnnRNNForwardTraining(handle(), rnn, seqlen, xd, x, hd, h, cd, c, wd, w, yd, y, + hod, ho, cod, co, workspace, length(workspace), reserve, length(reserve)) end end @@ -134,8 +116,8 @@ end # TODO: can we just manipulate strides here? # TODO: should use repmat, but this isn't implemented. hBatch(x::AbstractVector, h::CuVector) = h -hBatch(x::AbstractMatrix, h::CuVector) = h .* CuArrays.ones(1, size(x, 2)) -hBatch(x::AbstractMatrix, h::CuMatrix) = h .* CuArrays.ones(1, size(h,2) == 1 ? size(x,2) : 1) +hBatch(x::AbstractMatrix, h::CuVector) = h .*CuArrays.ones(1, size(x, 2)) +hBatch(x::AbstractMatrix, h::CuMatrix) = h .*CuArrays.ones(1, size(h,2) == 1 ? size(x,2) : 1) function forward(rnn::RNNDesc{T}, x::CuArray{T}, h_::CuArray{T}, c_ = nothing, train = Val{false}) where T h = hBatch(x, h_) @@ -169,18 +151,6 @@ end forwardTrain(rnn::RNNDesc{T}, x::CuArray{T}, h::CuArray{T}, c = nothing) where T = forward(rnn, x, h, c, Val{true}) -function cudnnRNNBackwardData(rnn::RNNDesc{T}, seqlen, yd, y, dyd, dy, dhod, dho, dcod, dco, - wd, w, hd, h, cd, c, dxd, dx, dhd, dh, dcd, dc, ws, rs) where T - @check ccall((:cudnnRNNBackwardData,libcudnn),cudnnStatus_t, - (Ptr{Nothing}, Ptr{Nothing}, Cint, - Ptr{Ptr{Nothing}}, CuPtr{T}, Ptr{Ptr{Nothing}}, CuPtr{T}, Ptr{Nothing}, CuPtr{T}, - Ptr{Nothing}, CuPtr{T}, Ptr{Nothing}, CuPtr{T}, Ptr{Nothing}, CuPtr{T}, Ptr{Nothing}, - CuPtr{T}, Ptr{Ptr{Nothing}}, CuPtr{T}, Ptr{Nothing}, CuPtr{T}, Ptr{Nothing}, CuPtr{T}, - CuPtr{Nothing}, Csize_t, CuPtr{Nothing}, Csize_t), - handle(), rnn, seqlen, yd, y, dyd, dy, dhod, dho, dcod, dco, - wd, w, hd, h, cd, c, dxd, dx, dhd, dh, dcd, dc, ws, length(ws), rs, length(rs)) -end - function backwardData(rnn::RNNDesc{T}, y, dy_, dho, dco, h, c, reserve) where T # Same as above, any more efficient way? dy = dy_ isa Integer ? zero(y) : dy_ @@ -188,37 +158,24 @@ function backwardData(rnn::RNNDesc{T}, y, dy_, dho, dco, h, c, reserve) where T dx = y isa AbstractVector ? similar(dy, rnn.input) : similar(dy, rnn.input, size(dy, 2)) dh = similar(h) dc = c == nothing ? nothing : similar(c) - cudnnRNNBackwardData(rnn, 1, + CUDNN.cudnnRNNBackwardData(handle(), rnn, 1, yd, y, yd, dy, hDesc(dho)..., hDesc(dco)..., FilterDesc(T, (1, 1, length(rnn.params))), rnn.params, hDesc(h)..., hDesc(c)..., xDesc(dx), dx, hDesc(dh)..., hDesc(dc)..., - workspace[], reserve) + workspace[], length(workspace[]), reserve, length(reserve)) return c == nothing ? (dx, dh) : (dx, dh, dc) end backwardData(rnn, y, dy, dho, hx, reserve) = backwardData(rnn, y, dy, dho, nothing, hx, nothing, reserve) -function cudnnRNNBackwardWeights(rnn::RNNDesc{T}, seqlen, xd, x, hd, h, yd, y, dwd, dw, - workspace, reserve) where T - @check ccall((:cudnnRNNBackwardWeights,libcudnn), cudnnStatus_t, - (Ptr{Nothing}, Ptr{Nothing}, Cint, # handle, rnnDesc, seqLength - Ptr{Ptr{Nothing}}, CuPtr{T}, #x - Ptr{Nothing}, CuPtr{T}, #hx - Ptr{Ptr{Nothing}}, CuPtr{T}, #y - CuPtr{Nothing}, Csize_t, #ws - Ptr{Nothing}, CuPtr{T}, #dw - CuPtr{Nothing}, Csize_t), #rs - handle(), rnn, seqlen, xd, x, hd, h, yd, y, - workspace, length(workspace), dwd, dw, reserve, length(reserve)) -end - function backwardWeights(rnn::RNNDesc{T}, x, h, y, reserve) where T dw = zero(rnn.params) - cudnnRNNBackwardWeights(rnn, 1, + CUDNN.cudnnRNNBackwardWeights(handle(), rnn, 1, xDesc(x), x, hDesc(h)..., xDesc(y), y, + workspace[], length(workspace[]), FilterDesc(T, (1, 1, length(dw))), dw, - workspace[], reserve) + reserve, length(reserve)) return params(dw, rnn.input, rnn.hidden, ngates(rnn)) end From 1e7ff4f65ddb6ee1eada1f9e960ade56593e89d9 Mon Sep 17 00:00:00 2001 From: Tim Besard Date: Thu, 29 Aug 2019 17:26:10 +0200 Subject: [PATCH 2/8] Query the worksize. --- src/cuda/curnn.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/cuda/curnn.jl b/src/cuda/curnn.jl index c37d031c..bbd4e122 100644 --- a/src/cuda/curnn.jl +++ b/src/cuda/curnn.jl @@ -158,11 +158,12 @@ function backwardData(rnn::RNNDesc{T}, y, dy_, dho, dco, h, c, reserve) where T dx = y isa AbstractVector ? similar(dy, rnn.input) : similar(dy, rnn.input, size(dy, 2)) dh = similar(h) dc = c == nothing ? nothing : similar(c) + workspace = getworkspace(rnn, 1, yd) CUDNN.cudnnRNNBackwardData(handle(), rnn, 1, yd, y, yd, dy, hDesc(dho)..., hDesc(dco)..., FilterDesc(T, (1, 1, length(rnn.params))), rnn.params, hDesc(h)..., hDesc(c)..., xDesc(dx), dx, hDesc(dh)..., hDesc(dc)..., - workspace[], length(workspace[]), reserve, length(reserve)) + workspace, length(workspace), reserve, length(reserve)) return c == nothing ? (dx, dh) : (dx, dh, dc) end From 4942d7fcfd405b7790c038e3e557015da38d8152 Mon Sep 17 00:00:00 2001 From: Tim Besard Date: Fri, 30 Aug 2019 08:39:51 +0200 Subject: [PATCH 3/8] Move functionality over to CuArrays. --- src/cuda/cudnn.jl | 148 +------------------------------ src/cuda/curnn.jl | 221 ++++------------------------------------------ 2 files changed, 21 insertions(+), 348 deletions(-) diff --git a/src/cuda/cudnn.jl b/src/cuda/cudnn.jl index aa16f926..d394182e 100644 --- a/src/cuda/cudnn.jl +++ b/src/cuda/cudnn.jl @@ -1,149 +1,5 @@ -using CuArrays.CUDNN: handle, TensorDesc, FilterDesc - -import CuArrays.CUDAdrv: CU_NULL - -using LinearAlgebra - -mutable struct DropoutDesc - ptr::Ptr{Nothing} - states::CuVector{UInt8} -end - -Base.unsafe_convert(::Type{Ptr{Nothing}}, dd::DropoutDesc) = dd.ptr - -function DropoutDesc(ρ::Real; seed::Integer=0) - d = [C_NULL] - s = Csize_t[0] - CUDNN.cudnnCreateDropoutDescriptor(d) - CUDNN.cudnnDropoutGetStatesSize(handle(), s) - states = CuArray{UInt8}(undef, s[]) # TODO: can we drop this when ρ=0? - desc = DropoutDesc(d[], states) - CUDNN.cudnnSetDropoutDescriptor(desc, handle(), ρ, states, length(states), seed) - finalizer(desc) do x - CUDNN.cudnnDestroyDropoutDescriptor(x) - end - return desc -end - -@inline _wsize(y) = (map(_ -> 1, size(y)[1:end-2])..., size(y)[end-1], 1) - -@inline _reddims(y) = (collect(1:ndims(y)-2)..., ndims(y)) - -mutable struct BNCache - mean - ivar -end - -BNCache() = BNCache(nothing, nothing) - -# NOTE: CuDNN supports only 4D and 5D Tensors for BatchNorm Operations -# so reshape a 2D Tensor into 4D -batchnorm(g::CuArray{T}, b::CuArray{T}, x::CuArray{T, 2}, - running_mean::CuArray{T}, running_var::CuArray{T}, momentum; - cache = nothing, alpha = T(1), beta = T(0), - eps = T(1e-5), training = true) where T<:Union{Float32, Float64} = - dropdims(batchnorm(g, b, reshape(x, 1, 1, size(x, 1), size(x, 2)), running_mean, running_var, momentum, - cache = cache, alpha = alpha, beta = beta, eps = eps, training = training), dims = (1, 2)) - -function batchnorm(g::CuArray{T}, b::CuArray{T}, x::Union{CuArray{T, 4},CuArray{T,5}}, - running_mean::CuArray{T}, running_var::CuArray{T}, momentum; - cache = nothing, alpha = T(1), beta = T(0), - eps = T(1e-5), training = true) where T<:Union{Float32, Float64} - y = similar(x) - cudnnBNForward!(y, g, b, x, running_mean, running_var, momentum, cache = cache, - alpha = alpha, beta = beta, eps = eps, training = training) - y -end - -function cudnnBNForward!(y::CuArray{T}, g::CuArray{T}, b::CuArray{T}, x::CuArray{T}, - running_mean::CuArray{T}, running_var::CuArray{T}, - momentum; cache = nothing, - alpha = T(1), beta = T(0), - eps = T(1e-5), training = true) where T<:Union{Float32, Float64} - dims = _wsize(x) - if eps < CUDNN.CUDNN_BN_MIN_EPSILON - # warn("eps ",eps," is too small for CuDNN so eps has been assigned the value ", CUDNN.CUDNN_BN_MIN_EPSILON) - eps = CUDNN.CUDNN_BN_MIN_EPSILON - end - xd = TensorDesc(x) - yd = TensorDesc(y) - gd = TensorDesc(T, dims) - - if training - - if cache !== nothing - mean = zeros(CuArray{T}, dims...) - ivar = ones(CuArray{T}, dims...) - else - mean = CU_NULL - ivar = CU_NULL - end - - CUDNN.cudnnBatchNormalizationForwardTraining(handle(), CUDNN.CUDNN_BATCHNORM_SPATIAL, Ref(T(alpha)), Ref(T(beta)), xd, x, yd, y, gd, g, b, momentum, running_mean, running_var, eps, mean, ivar) - - if cache !== nothing - cache.mean = mean - cache.ivar = ivar - end - else - CUDNN.cudnnBatchNormalizationForwardInference(handle(), CUDNN.CUDNN_BATCHNORM_SPATIAL, Ref(T(alpha)), Ref(T(beta)), xd, x, yd, y, gd, g, b, running_mean, running_var, eps) - end -end - -function ∇batchnorm(g::CuArray{T}, b::CuArray{T}, x::CuArray{T, 2}, dy::CuArray{T, 2}, - running_mean::CuArray{T}, running_var::CuArray{T}, momentum; - cache = nothing, eps = T(1e-5), alpha = T(1), - beta = T(0), training = true) where T<:Union{Float32, Float64} - dg, db, dx = ∇batchnorm(g, b, reshape(x, 1, 1, size(x, 1), size(x, 2)), reshape(dy, 1, 1, size(dy, 1), - size(dy, 2)), running_mean, running_var, momentum, cache = cache, eps = eps, - alpha = alpha, beta = beta, training = training) - (dg, db, dropdims(dx, dims = (1, 2))) -end - -function ∇batchnorm(g::CuArray{T}, b::CuArray{T}, x::CuArray{T}, dy::CuArray{T}, - running_mean::CuArray{T}, running_var::CuArray{T}, momentum; - cache = nothing, eps = T(1e-5), alpha = T(1), - beta = T(0), training = true) where T<:Union{Float32, Float64} - dg = similar(g) - db = similar(b) - dx = similar(x) - cudnnBNBackward!(dg, g, db, dx, x, dy, running_mean, running_var, T(momentum), - training = training, cache = cache, eps = eps, alpha = alpha, beta = beta) - (dg, db, dx) -end - -function cudnnBNBackward!(dg::CuArray{T}, g::CuArray{T}, db::CuArray{T}, - dx::CuArray{T}, x::CuArray{T}, dy::CuArray{T}, - running_mean::CuArray{T}, running_var::CuArray{T}, - momentum; cache = nothing, eps = T(1e-5), - alpha = T(1), beta = T(0), - dalpha = T(1), dbeta = T(0), training = true) where T<:Union{Float32, Float64} - if training - xd = TensorDesc(x) - dyd = TensorDesc(dy) - dxd = TensorDesc(dx) - gd = TensorDesc(T, _wsize(x)) - if cache !== nothing - mean, ivar = cache.mean, cache.ivar - info("mean and ivar are fetched from the cache") - else - mean, ivar = CU_NULL, CU_NULL - end - - if eps < CUDNN.CUDNN_BN_MIN_EPSILON - eps = CUDNN.CUDNN_BN_MIN_EPSILON - end - - CUDNN.cudnnBatchNormalizationBackward(handle(), CUDNN.CUDNN_BATCHNORM_SPATIAL, Ref(T(alpha)), Ref(T(beta)), Ref(T(dalpha)), Ref(T(dbeta)), xd, x, dyd, dy, dxd, dx, gd, g, dg, db, eps, mean, ivar) - else - ivar = 1 ./ sqrt.(reshape(running_var, _wsize(x)) .+ eps) - dx .= dy .* reshape(g, _wsize(x)) .* ivar - dg .= squeeze(sum(dy .* (x .- reshape(running_mean, _wsize(x))) .* ivar, _reddims(dy)), dims = (1,2,4)) - db .= squeeze(sum(dy, _reddims(dy)), dims = (1,2,4)) - end -end - -# Flux Interface +import ..Flux: data +import CuArrays.CUDNN: batchnorm, ∇batchnorm (BN::Flux.BatchNorm)(x::Union{CuArray{T,2},CuArray{T,4},CuArray{T,5}}, cache = nothing) where T<:Union{Float32, Float64} = BN.λ.(batchnorm(BN.γ, BN.β, x, BN.μ, BN.σ², BN.momentum; cache = cache, alpha = 1, beta = 0, eps = BN.ϵ, training = Flux.istraining())) diff --git a/src/cuda/curnn.jl b/src/cuda/curnn.jl index bbd4e122..edbf58c5 100644 --- a/src/cuda/curnn.jl +++ b/src/cuda/curnn.jl @@ -1,190 +1,7 @@ -using CuArrays.CUDNN: handle, cudnnDataType, TensorDesc, FilterDesc - -import CuArrays.CUDAdrv: CU_NULL - -using LinearAlgebra - -const RNN_RELU = 0 # Stock RNN with ReLu activation -const RNN_TANH = 1 # Stock RNN with tanh activation -const LSTM = 2 # LSTM with no peephole connections -const GRU = 3 # Using h' = tanh(r * Uh(t-1) + Wx) and h = (1 - z) * h' + z * h(t-1) - -const LINEAR_INPUT = 0 -const SKIP_INPUT = 1 - -const UNIDIRECTIONAL = 0 -const BIDIRECTIONAL = 1 - -const RNN_ALGO_STANDARD = 0 -const RNN_ALGO_PERSIST_STATIC = 1 -const RNN_ALGO_PERSIST_DYNAMIC = 2 - -# param layout: -# RNN: [weight, bias] × [input, hidden] -# GRU: [weight, bias] × [input, hidden] × [reset, update, newmem] -# LSTM: [weight, bias] × [input, hidden] × [input, forget, newmem, output] - -function params(w::CuVector, input, hidden, n = 1) - slice(offset, shape) = reshape(view(w, offset.+(1:prod(shape))), shape) - wx = slice(0, (input, hidden*n)) - wh = slice(length(wx), (hidden, hidden*n)) - bias = view(w, length(wx)+length(wh) .+ (1:hidden*n)) - (wx, wh), bias -end - -mutable struct RNNDesc{T} - mode::Int - input::Int - hidden::Int - params::CuVector{T} - weights::NTuple{2,CuMatrix{T}} - bias::CuVector{T} - ptr::Ptr{Nothing} -end - -Base.unsafe_convert(::Type{Ptr{Nothing}}, d::RNNDesc) = d.ptr - -function rnnParamSize(T, r, input) - size = Csize_t[0] - CUDNN.cudnnGetRNNParamsSize(handle(), r, TensorDesc(T, (1,input,1)), size, cudnnDataType(T)) - return Int(size[])÷sizeof(T) -end - -ngates(mode) = [1, 1, 4, 3][mode+1] -ngates(r::RNNDesc) = ngates(r.mode) - -function RNNDesc{T}(mode::Int, input::Int, hidden::Int; layers = 1) where T - d = [C_NULL] - CUDNN.cudnnCreateRNNDescriptor(d) - - dropoutDesc = DropoutDesc(0) - inputMode = LINEAR_INPUT - direction = UNIDIRECTIONAL - algo = RNN_ALGO_STANDARD - CUDNN.cudnnSetRNNDescriptor_v6(handle(),d[],hidden,layers,dropoutDesc,CUDNN.cudnnRNNInputMode_t(inputMode),CUDNN.cudnnDirectionMode_t(direction),CUDNN.cudnnRNNMode_t(mode),CUDNN.cudnnRNNAlgo_t(algo),cudnnDataType(T)) - - w =CuArrays.zeros(T, rnnParamSize(T, d[], input)) - # TODO: avoid reserve allocation here - rd = RNNDesc{T}(mode, input, hidden, w, params(w, input, hidden, ngates(mode))..., d[]) - finalizer(rd) do x - CUDNN.cudnnDestroyRNNDescriptor(x) - end - return rd -end - -function rnnWorkspaceSize(r::RNNDesc, seqlen, xdesc) - size = Csize_t[0] - CUDNN.cudnnGetRNNWorkspaceSize(handle(), r, seqlen, xdesc, size) - return Int(size[]) -end - -const workspace = [CuVector{UInt8}(undef, 1)] - -getworkspace(bytes) = - length(workspace[]) ≥ bytes ? - workspace[] : - (workspace[] = CuVector{UInt8}(undef, bytes)) - -getworkspace(r::RNNDesc, seqlen, xdesc) = - getworkspace(rnnWorkspaceSize(r, seqlen, xdesc)) - -function rnnTrainingReserveSize(r::RNNDesc, seqlen, xdesc) - size = Csize_t[0] - CUDNN.cudnnGetRNNTrainingReserveSize(handle(), r, seqlen, xdesc, size) - return Int(size[]) -end - -function cudnnRNNForward(rnn::RNNDesc{T}, seqlen, xd, x, hd, h, cd, c, wd, w, yd, y, hod, ho, cod, co, - workspace, reserve=nothing) where T - if reserve == nothing - CUDNN.cudnnRNNForwardInference(handle(), rnn, seqlen, xd, x, hd, h, cd, c, wd, w, yd, y, - hod, ho, cod, co, workspace, length(workspace)) - else - CUDNN.cudnnRNNForwardTraining(handle(), rnn, seqlen, xd, x, hd, h, cd, c, wd, w, yd, y, - hod, ho, cod, co, workspace, length(workspace), reserve, length(reserve)) - end -end - -xDesc(x) = [TensorDesc(eltype(x), (1, size(x, 1), size(x, 2)))] - -hDesc(h::Nothing) = C_NULL, CU_NULL -hDesc(x::Integer) = (@assert x == 0; hDesc(nothing)) -function hDesc(h::CuArray) - TensorDesc(eltype(h), (size(h, 1), size(h, 2), 1)), h -end - -# TODO: can we just manipulate strides here? -# TODO: should use repmat, but this isn't implemented. -hBatch(x::AbstractVector, h::CuVector) = h -hBatch(x::AbstractMatrix, h::CuVector) = h .*CuArrays.ones(1, size(x, 2)) -hBatch(x::AbstractMatrix, h::CuMatrix) = h .*CuArrays.ones(1, size(h,2) == 1 ? size(x,2) : 1) - -function forward(rnn::RNNDesc{T}, x::CuArray{T}, h_::CuArray{T}, c_ = nothing, train = Val{false}) where T - h = hBatch(x, h_) - c = c_ == nothing ? nothing : hBatch(x, c_) - @assert size(x, 1) == rnn.input - @assert size(h, 1) == rnn.hidden - @assert size(x, 2) == size(h, 2) - seqLength = 1 - xdesc = xDesc(x) - y = x isa AbstractVector ? similar(x, rnn.hidden) : similar(x, rnn.hidden, size(x, 2)) - ho = similar(h) - ydesc = xDesc(y) - workspace = getworkspace(rnn, seqLength, xdesc) - reserve = train == Val{true} ? - CuVector{UInt8}(undef, rnnTrainingReserveSize(rnn, seqLength, xdesc)) : - nothing - co = c == nothing ? c : similar(c) - cudnnRNNForward(rnn, seqLength, - xdesc, x, - hDesc(h)..., - hDesc(c)..., - FilterDesc(T, (1, 1, length(rnn.params))), rnn.params, - ydesc, y, - hDesc(ho)..., - hDesc(co)..., - workspace, reserve) - result = c == nothing ? (y, ho) : (y, ho, co) - return train == Val{true} ? (reserve, result) : result -end - -forwardTrain(rnn::RNNDesc{T}, x::CuArray{T}, h::CuArray{T}, c = nothing) where T = - forward(rnn, x, h, c, Val{true}) - -function backwardData(rnn::RNNDesc{T}, y, dy_, dho, dco, h, c, reserve) where T - # Same as above, any more efficient way? - dy = dy_ isa Integer ? zero(y) : dy_ - yd = xDesc(y) - dx = y isa AbstractVector ? similar(dy, rnn.input) : similar(dy, rnn.input, size(dy, 2)) - dh = similar(h) - dc = c == nothing ? nothing : similar(c) - workspace = getworkspace(rnn, 1, yd) - CUDNN.cudnnRNNBackwardData(handle(), rnn, 1, - yd, y, yd, dy, hDesc(dho)..., hDesc(dco)..., - FilterDesc(T, (1, 1, length(rnn.params))), rnn.params, - hDesc(h)..., hDesc(c)..., xDesc(dx), dx, hDesc(dh)..., hDesc(dc)..., - workspace, length(workspace), reserve, length(reserve)) - return c == nothing ? (dx, dh) : (dx, dh, dc) -end - -backwardData(rnn, y, dy, dho, hx, reserve) = - backwardData(rnn, y, dy, dho, nothing, hx, nothing, reserve) - -function backwardWeights(rnn::RNNDesc{T}, x, h, y, reserve) where T - dw = zero(rnn.params) - CUDNN.cudnnRNNBackwardWeights(handle(), rnn, 1, - xDesc(x), x, hDesc(h)..., xDesc(y), y, - workspace[], length(workspace[]), - FilterDesc(T, (1, 1, length(dw))), dw, - reserve, length(reserve)) - return params(dw, rnn.input, rnn.hidden, ngates(rnn)) -end - -# Interface - import ..Flux: Flux, relu using CuArrays.CUDAnative using CuArrays: @cuindex, cudims +using LinearAlgebra function LinearAlgebra.copy_transpose!(dst::CuArray, src::CuArray) function kernel(dst, src) @@ -202,7 +19,7 @@ CuGRU{T} = Flux.GRUCell{<:CuArray{T,2},<:CuArray{T,1}} CuLSTM{T} = Flux.LSTMCell{<:CuArray{T,2},<:CuArray{T,1}} CuRNNs{T} = Union{CuRNN{T},CuGRU{T},CuLSTM{T}} -function copyparams!(m::CuRNNs, d::RNNDesc) +function copyparams!(m::CuRNNs, d::CUDNN.RNNDesc) Wi, Wh = d.weights copy_transpose!(Wi, m.Wi) copy_transpose!(Wh, m.Wh) @@ -210,19 +27,19 @@ function copyparams!(m::CuRNNs, d::RNNDesc) return end -function RNNDesc(m::CuRNNs{T}) where T +function CUDNN.RNNDesc(m::CuRNNs{T}) where T h, i = length(m.h), size(m.Wi, 2) mode = m isa CuRNN ? - (m.σ == tanh ? RNN_TANH : RNN_RELU) : - m isa CuGRU ? GRU : LSTM - r = RNNDesc{T}(mode, i, h) + (m.σ == tanh ? CUDNN.CUDNN_RNN_TANH : CUDNN.CUDNN_RNN_RELU) : + m isa CuGRU ? CUDNN.CUDNN_GRU : CUDNN.CUDNN_LSTM + r = CUDNN.RNNDesc{T}(mode, i, h) return r end const descs = WeakKeyDict() function desc(rnn) - d = haskey(descs, rnn) ? descs[rnn] : (descs[rnn] = RNNDesc(rnn)) + d = haskey(descs, rnn) ? descs[rnn] : (descs[rnn] = CUDNN.RNNDesc(rnn)) copyparams!(rnn, d) return d end @@ -230,17 +47,17 @@ end using ..Flux: @adjoint function (m::CuRNN{T})(h::CuArray{T}, x::CuArray{T}) where T <: Union{Float32,Float64} - y, h′ = forward(desc(m), x, h) + y, h′ = CUDNN.forward(desc(m), x, h) return h′, y end function (m::CuGRU{T})(h::CuArray{T}, x::CuArray{T}) where T <: Union{Float32,Float64} - y, h′ = forward(desc(m), x, h) + y, h′ = CUDNN.forward(desc(m), x, h) return h′, y end function (m::CuLSTM{T})(h::NTuple{2,CuArray{T}}, x::CuArray{T}) where T <: Union{Float32,Float64} - y, h′, c′ = forward(desc(m), x, h[1], h[2]) + y, h′, c′ = CUDNN.forward(desc(m), x, h[1], h[2]) return (h′, c′), y end @@ -257,12 +74,12 @@ unbroadcast(x::AbstractArray, Δ) = for RNN in (CuRNN, CuGRU) @eval @adjoint function (m::$RNN{T})(h::CuArray{T}, x::CuArray{T}) where T <: Union{Float32,Float64} - reserve, (y, ho) = forwardTrain(desc(m), x, h) + reserve, (y, ho) = CUDNN.forwardTrain(desc(m), x, h) (ho, y), function (Δ) dho, dy = Δ - h_ = hBatch(x, h) - dx, dh = backwardData(descs[m], y, dy, dho, h_, reserve) - (dWi, dWh), db = backwardWeights(descs[m], x, h_, y, reserve) + h_ = CUDNN.hBatch(x, h) + dx, dh = CUDNN.backwardData(descs[m], y, dy, dho, h_, reserve) + (dWi, dWh), db = CUDNN.backwardWeights(descs[m], x, h_, y, reserve) dm = Ref{Any}((σ=nothing,Wi=transpose(dWi),Wh=transpose(dWh),b=db,h=nothing)) (dm, unbroadcast(h, dh), dx) end @@ -270,14 +87,14 @@ for RNN in (CuRNN, CuGRU) end @adjoint function (m::CuLSTM)((h, c)::Tuple{CuArray{T},CuArray{T}}, x::CuArray{T}) where T <: Union{Float32,Float64} - reserve, (y, ho, co) = forwardTrain(desc(m), x, h, c) + reserve, (y, ho, co) = CUDNN.forwardTrain(desc(m), x, h, c) ((ho, co), y), function (Δ) dhc, dy = Δ dho, dco = dhc === nothing ? (nothing, nothing) : dhc - h_ = hBatch(x, h) - c_ = hBatch(x, c) - dx, dh, dc = backwardData(descs[m], y, dy, dho, dco, h_, c_, reserve) - (dWi, dWh), db = backwardWeights(descs[m], x, h_, y, reserve) + h_ = CUDNN.hBatch(x, h) + c_ = CUDNN.hBatch(x, c) + dx, dh, dc = CUDNN.backwardData(descs[m], y, dy, dho, dco, h_, c_, reserve) + (dWi, dWh), db = CUDNN.backwardWeights(descs[m], x, h_, y, reserve) dm = Ref{Any}((Wi=transpose(dWi),Wh=transpose(dWh),b=db,h=nothing,c=nothing)) (dm, (unbroadcast(h, dh), unbroadcast(c, dc)), dx) end From 6ea2557c468090c64ced5e831a8cdd990ecb5281 Mon Sep 17 00:00:00 2001 From: Tim Besard Date: Fri, 30 Aug 2019 13:41:15 +0200 Subject: [PATCH 4/8] Use correct CuArrays branch for CI. --- Manifest.toml | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/Manifest.toml b/Manifest.toml index 17eb544e..e54c4a92 100644 --- a/Manifest.toml +++ b/Manifest.toml @@ -105,8 +105,10 @@ uuid = "a8cc5b0e-0ffa-5ad4-8c14-923d3ee1735f" version = "4.0.0" [[CuArrays]] -deps = ["AbstractFFTs", "Adapt", "CUDAapi", "CUDAdrv", "CUDAnative", "GPUArrays", "LinearAlgebra", "MacroTools", "NNlib", "Printf", "Random", "Requires", "SparseArrays", "TimerOutputs"] -git-tree-sha1 = "46b48742a84bb839e74215b7e468a4a1c6ba30f9" +deps = ["AbstractFFTs", "Adapt", "CEnum", "CUDAapi", "CUDAdrv", "CUDAnative", "GPUArrays", "LinearAlgebra", "MacroTools", "NNlib", "Printf", "Random", "Requires", "SparseArrays", "TimerOutputs"] +git-tree-sha1 = "8189fcb50b24998bad7518e52443fdb542403093" +repo-rev = "tb/flux" +repo-url = "https://github.com/JuliaGPU/CuArrays.jl.git" uuid = "3a865a2d-5b23-5a0f-bc46-62713ec82fae" version = "1.2.1" @@ -264,7 +266,7 @@ uuid = "69de0a69-1ddd-5017-9359-2bf0b02dc9f0" version = "0.3.7" [[Pkg]] -deps = ["Dates", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "UUIDs"] +deps = ["Dates", "LibGit2", "Markdown", "Printf", "REPL", "Random", "SHA", "UUIDs"] uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" [[Printf]] From c5e56b7e04fcc24d240c3ca8711e3174fb29c82f Mon Sep 17 00:00:00 2001 From: Mike Innes Date: Tue, 17 Sep 2019 17:22:35 +0100 Subject: [PATCH 5/8] move setweights and copy_transpose --- Manifest.toml | 2 +- Project.toml | 1 - src/cuda/curnn.jl | 22 +--------------------- test/cuda/curnn.jl | 4 ++-- 4 files changed, 4 insertions(+), 25 deletions(-) diff --git a/Manifest.toml b/Manifest.toml index e5c84399..299a40b5 100644 --- a/Manifest.toml +++ b/Manifest.toml @@ -106,7 +106,7 @@ version = "4.0.0" [[CuArrays]] deps = ["AbstractFFTs", "Adapt", "CEnum", "CUDAapi", "CUDAdrv", "CUDAnative", "GPUArrays", "LinearAlgebra", "MacroTools", "NNlib", "Printf", "Random", "Requires", "SparseArrays", "TimerOutputs"] -git-tree-sha1 = "155349d2c40568a23cbc4599f0e17e2fdf1bbbcc" +git-tree-sha1 = "63b4a10d3a4f22ef215d0970483b18296717d1fb" repo-rev = "tb/flux" repo-url = "https://github.com/JuliaGPU/CuArrays.jl.git" uuid = "3a865a2d-5b23-5a0f-bc46-62713ec82fae" diff --git a/Project.toml b/Project.toml index 2fcdc943..7cd78984 100644 --- a/Project.toml +++ b/Project.toml @@ -11,7 +11,6 @@ Colors = "5ae59095-9a9b-59fe-a467-6f913c188581" CuArrays = "3a865a2d-5b23-5a0f-bc46-62713ec82fae" DelimitedFiles = "8bb1440f-4735-579b-a4ab-409b98df4dab" Juno = "e5e0dc1b-0480-54bc-9374-aad01c23163d" -LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" diff --git a/src/cuda/curnn.jl b/src/cuda/curnn.jl index 19f6e9df..86422d03 100644 --- a/src/cuda/curnn.jl +++ b/src/cuda/curnn.jl @@ -1,32 +1,12 @@ import ..Flux: Flux, relu using CuArrays.CUDAnative using CuArrays: @cuindex, cudims -using LinearAlgebra - -function LinearAlgebra.copy_transpose!(dst::CuArray, src::CuArray) - function kernel(dst, src) - I = @cuindex dst - dst[I...] = src[reverse(I)...] - return - end - blk, thr = cudims(dst) - @cuda blocks=blk threads=thr kernel(dst, src) - return dst -end CuRNN{T} = Flux.RNNCell{<:Union{typeof(tanh),typeof(relu)},<:CuArray{T,2},<:CuArray{T,1}} CuGRU{T} = Flux.GRUCell{<:CuArray{T,2},<:CuArray{T,1}} CuLSTM{T} = Flux.LSTMCell{<:CuArray{T,2},<:CuArray{T,1}} CuRNNs{T} = Union{CuRNN{T},CuGRU{T},CuLSTM{T}} -function copyparams!(m::CuRNNs, d::CUDNN.RNNDesc) - Wi, Wh = d.weights - copy_transpose!(Wi, m.Wi) - copy_transpose!(Wh, m.Wh) - copy_transpose!(d.bias, m.b) - return -end - function CUDNN.RNNDesc(m::CuRNNs{T}) where T h, i = length(m.h), size(m.Wi, 2) mode = m isa CuRNN ? @@ -40,7 +20,7 @@ const descs = WeakKeyDict() function desc(rnn) d = haskey(descs, rnn) ? descs[rnn] : (descs[rnn] = CUDNN.RNNDesc(rnn)) - copyparams!(rnn, d) + CUDNN.setweights!(d, rnn.Wi, rnn.Wh, rnn.b) return d end diff --git a/test/cuda/curnn.jl b/test/cuda/curnn.jl index 1e834d14..e417ea58 100644 --- a/test/cuda/curnn.jl +++ b/test/cuda/curnn.jl @@ -22,8 +22,8 @@ end rand(10, batch_size) cux = gpu(x) - y, back = forward((r, x) -> (r(x)), rnn, x) - cuy, cuback = forward((r, x) -> (r(x)), curnn, cux) + y, back = forward((r, x) -> r(x), rnn, x) + cuy, cuback = forward((r, x) -> r(x), curnn, cux) @test y ≈ collect(cuy) @test haskey(Flux.CUDA.descs, curnn.cell) From 46bc8e5e648b5f5fe2811b8c21912367437cbb47 Mon Sep 17 00:00:00 2001 From: Mike Innes Date: Thu, 26 Sep 2019 17:14:18 +0100 Subject: [PATCH 6/8] move pullbacks to CuArrays --- Manifest.toml | 12 ++++++------ src/cuda/curnn.jl | 27 +++++++++++---------------- 2 files changed, 17 insertions(+), 22 deletions(-) diff --git a/Manifest.toml b/Manifest.toml index 299a40b5..d10fc71b 100644 --- a/Manifest.toml +++ b/Manifest.toml @@ -46,9 +46,9 @@ version = "0.6.2" [[CUDAapi]] deps = ["Libdl", "Logging"] -git-tree-sha1 = "9b2b4b71d6b7f946c9689bb4dea03ff92e3c7091" +git-tree-sha1 = "e063efb91cfefd7e6afd92c435d01398107a500b" uuid = "3895d2a7-ec45-59b8-82bb-cfc6a382f9b3" -version = "1.1.0" +version = "1.2.0" [[CUDAdrv]] deps = ["CUDAapi", "Libdl", "Printf"] @@ -105,8 +105,8 @@ uuid = "a8cc5b0e-0ffa-5ad4-8c14-923d3ee1735f" version = "4.0.0" [[CuArrays]] -deps = ["AbstractFFTs", "Adapt", "CEnum", "CUDAapi", "CUDAdrv", "CUDAnative", "GPUArrays", "LinearAlgebra", "MacroTools", "NNlib", "Printf", "Random", "Requires", "SparseArrays", "TimerOutputs"] -git-tree-sha1 = "63b4a10d3a4f22ef215d0970483b18296717d1fb" +deps = ["AbstractFFTs", "Adapt", "CEnum", "CUDAapi", "CUDAdrv", "CUDAnative", "DataStructures", "GPUArrays", "LinearAlgebra", "MacroTools", "NNlib", "Printf", "Random", "Requires", "SparseArrays", "TimerOutputs"] +git-tree-sha1 = "4e638627673078c58b6e6bb789937822d83350ff" repo-rev = "tb/flux" repo-url = "https://github.com/JuliaGPU/CuArrays.jl.git" uuid = "3a865a2d-5b23-5a0f-bc46-62713ec82fae" @@ -172,9 +172,9 @@ version = "0.10.3" [[GPUArrays]] deps = ["Adapt", "FFTW", "FillArrays", "LinearAlgebra", "Printf", "Random", "Serialization", "StaticArrays", "Test"] -git-tree-sha1 = "b5009ac44b141ded5e6f04c4db83807970f56e91" +git-tree-sha1 = "77e27264276fe97a7e7fb928bf8999a145abc018" uuid = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" -version = "1.0.2" +version = "1.0.3" [[IRTools]] deps = ["InteractiveUtils", "MacroTools", "Test"] diff --git a/src/cuda/curnn.jl b/src/cuda/curnn.jl index 86422d03..fb454729 100644 --- a/src/cuda/curnn.jl +++ b/src/cuda/curnn.jl @@ -56,7 +56,7 @@ unbroadcast(x::AbstractArray, Δ) = coerce_cuda(x::Union{CuArray,Nothing}) = x coerce_cuda(x::Tuple) = coerce_cuda.(x) -coerce_cuda(x) = x .+ CuArrays.fill(0) +coerce_cuda(x::AbstractArray) = x .+ CuArrays.fill(0) function struct_grad!(cx::Zygote.Context, x, x̄) for f in fieldnames(typeof(x)) @@ -69,28 +69,23 @@ end for RNN in (CuRNN, CuGRU) @eval @adjoint function (m::$RNN{T})(h::CuArray{T}, x::CuArray{T}) where T <: Union{Float32,Float64} - reserve, (y, ho) = CUDNN.forwardTrain(desc(m), x, h) + (y, ho), back = CUDNN.pullback(desc(m), x, h) (ho, y), function (Δ) - dho, dy = coerce_cuda(Δ) - h_ = CUDNN.hBatch(x, h) - dx, dh = CUDNN.backwardData(descs[m], y, dy, dho, h_, reserve) - (dWi, dWh), db = CUDNN.backwardWeights(descs[m], x, h_, y, reserve) - dm = struct_grad!(__context__, m, (σ=nothing,Wi=transpose(dWi),Wh=transpose(dWh),b=db,h=nothing)) - (dm, unbroadcast(h, dh), dx) + dho, dy = coerce_cuda(Δ) # Support FillArrays etc. + m̄ = back(dy, dho) + dm = struct_grad!(__context__, m, (σ=nothing,Wi=transpose(m̄.Wi),Wh=transpose(m̄.Wh),b=m̄.b,h=nothing)) + (dm, unbroadcast(h, m̄.h), m̄.x) end end end @adjoint function (m::CuLSTM)((h, c)::Tuple{CuArray{T},CuArray{T}}, x::CuArray{T}) where T <: Union{Float32,Float64} - reserve, (y, ho, co) = CUDNN.forwardTrain(desc(m), x, h, c) + (y, ho, co), back = CUDNN.pullback(desc(m), x, h, c) ((ho, co), y), function (Δ) - dhc, dy = coerce_cuda(Δ) + dhc, dy = coerce_cuda(Δ) # Support FillArrays etc. dho, dco = dhc === nothing ? (nothing, nothing) : dhc - h_ = CUDNN.hBatch(x, h) - c_ = CUDNN.hBatch(x, c) - dx, dh, dc = CUDNN.backwardData(descs[m], y, dy, dho, dco, h_, c_, reserve) - (dWi, dWh), db = CUDNN.backwardWeights(descs[m], x, h_, y, reserve) - dm = struct_grad!(__context__, m, (Wi=transpose(dWi),Wh=transpose(dWh),b=db,h=nothing,c=nothing)) - (dm, (unbroadcast(h, dh), unbroadcast(c, dc)), dx) + m̄ = back(dy, dho, dco) + dm = struct_grad!(__context__, m, (σ=nothing,Wi=transpose(m̄.Wi),Wh=transpose(m̄.Wh),b=m̄.b,h=nothing,c=nothing)) + (dm, (unbroadcast(h, m̄.h), unbroadcast(c, m̄.c)), m̄.x) end end From 691a29cf32bb01e9ca528ab869d72a17a1dec3a4 Mon Sep 17 00:00:00 2001 From: Mike Innes Date: Fri, 27 Sep 2019 14:15:58 +0100 Subject: [PATCH 7/8] cudnn bug is fixed --- Manifest.toml | 2 +- test/cuda/cuda.jl | 8 +++----- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/Manifest.toml b/Manifest.toml index d10fc71b..9919a94d 100644 --- a/Manifest.toml +++ b/Manifest.toml @@ -106,7 +106,7 @@ version = "4.0.0" [[CuArrays]] deps = ["AbstractFFTs", "Adapt", "CEnum", "CUDAapi", "CUDAdrv", "CUDAnative", "DataStructures", "GPUArrays", "LinearAlgebra", "MacroTools", "NNlib", "Printf", "Random", "Requires", "SparseArrays", "TimerOutputs"] -git-tree-sha1 = "4e638627673078c58b6e6bb789937822d83350ff" +git-tree-sha1 = "cc22ec1abd471b4529883a8174944b513d75ab33" repo-rev = "tb/flux" repo-url = "https://github.com/JuliaGPU/CuArrays.jl.git" uuid = "3a865a2d-5b23-5a0f-bc46-62713ec82fae" diff --git a/test/cuda/cuda.jl b/test/cuda/cuda.jl index 3508e561..20399ef7 100644 --- a/test/cuda/cuda.jl +++ b/test/cuda/cuda.jl @@ -51,9 +51,7 @@ end end if CuArrays.libcudnn != nothing - @info "Testing Flux/CUDNN" - include("cudnn.jl") - if !haskey(ENV, "CI_DISABLE_CURNN_TEST") - include("curnn.jl") - end + @info "Testing Flux/CUDNN" + include("cudnn.jl") + include("curnn.jl") end From e287982b7897c2674358e7a753570b3a5235a8f4 Mon Sep 17 00:00:00 2001 From: Mike Innes Date: Fri, 27 Sep 2019 14:55:30 +0100 Subject: [PATCH 8/8] use CuArrays master --- Manifest.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/Manifest.toml b/Manifest.toml index 9919a94d..4d825f17 100644 --- a/Manifest.toml +++ b/Manifest.toml @@ -106,8 +106,8 @@ version = "4.0.0" [[CuArrays]] deps = ["AbstractFFTs", "Adapt", "CEnum", "CUDAapi", "CUDAdrv", "CUDAnative", "DataStructures", "GPUArrays", "LinearAlgebra", "MacroTools", "NNlib", "Printf", "Random", "Requires", "SparseArrays", "TimerOutputs"] -git-tree-sha1 = "cc22ec1abd471b4529883a8174944b513d75ab33" -repo-rev = "tb/flux" +git-tree-sha1 = "45683305171430978c17f496969dc9b6d3094a51" +repo-rev = "master" repo-url = "https://github.com/JuliaGPU/CuArrays.jl.git" uuid = "3a865a2d-5b23-5a0f-bc46-62713ec82fae" version = "1.3.0"