Update the libcudnn_handle

This commit is contained in:
Avik Pal 2018-10-26 10:24:30 +05:30
parent ccd86d4795
commit b838c0bc04
2 changed files with 15 additions and 16 deletions

View File

@ -1,5 +1,5 @@
using CuArrays.CUDNN: @check, libcudnn, cudnnStatus_t, cudnnTensorDescriptor_t,
cudnnBatchNormMode_t, cudnnHandle_t, libcudnn_handle, cudnnDataType, TensorDesc, FilterDesc
cudnnBatchNormMode_t, cudnnHandle_t, handle, cudnnDataType, TensorDesc, FilterDesc
import ..Flux: data
mutable struct DropoutDesc
@ -13,11 +13,11 @@ function DropoutDesc(ρ::Real; seed::Integer=0)
d = [C_NULL]
s = Csize_t[0]
@check ccall((:cudnnCreateDropoutDescriptor,libcudnn), cudnnStatus_t, (Ptr{Ptr{Nothing}},), d)
@check ccall((:cudnnDropoutGetStatesSize,libcudnn),cudnnStatus_t,(Ptr{Nothing},Ptr{Csize_t}),libcudnn_handle[],s)
@check ccall((:cudnnDropoutGetStatesSize,libcudnn),cudnnStatus_t,(Ptr{Nothing},Ptr{Csize_t}),handle(),s)
states = CuArray{UInt8}(s[]) # TODO: can we drop this when ρ=0?
desc = DropoutDesc(d[], states)
@check ccall((:cudnnSetDropoutDescriptor,libcudnn),cudnnStatus_t,(Ptr{Nothing},Ptr{Nothing},Cfloat,Ptr{Nothing},Csize_t,Culonglong),
desc,libcudnn_handle[],ρ,states,length(states),seed)
desc,handle(),ρ,states,length(states),seed)
finalizer(desc) do x
@check ccall((:cudnnDestroyDropoutDescriptor,libcudnn),cudnnStatus_t,(Ptr{Nothing},),x)
end
@ -84,7 +84,7 @@ function cudnnBNForward!(y::CuArray{T}, g::CuArray{T}, b::CuArray{T}, x::CuArray
Ptr{Nothing}, Ptr{T}, Ptr{T},
Cdouble, Ptr{T}, Ptr{T},
Cdouble, Ptr{T}, Ptr{T}),
libcudnn_handle[], BATCHNORM_SPATIAL,
handle(), BATCHNORM_SPATIAL,
Ref(T(alpha)), Ref(T(beta)),
xd, x,
yd, y,
@ -105,7 +105,7 @@ function cudnnBNForward!(y::CuArray{T}, g::CuArray{T}, b::CuArray{T}, x::CuArray
Ptr{Nothing}, Ptr{T}, Ptr{T},
Ptr{T}, Ptr{T},
Cdouble),
libcudnn_handle[], BATCHNORM_SPATIAL,
handle(), BATCHNORM_SPATIAL,
Ref(T(alpha)), Ref(T(beta)),
xd, x,
yd, y,
@ -146,7 +146,6 @@ function cudnnBNBackward!(dg::CuArray{T}, g::CuArray{T}, db::CuArray{T},
end
if eps < BATCHNORM_MIN_EPS
# warn("eps ",eps," is too small for CuDNN so eps has been assigned the value ", BATCHNORM_MIN_EPS)
eps = BATCHNORM_MIN_EPS
end
@ -159,7 +158,7 @@ function cudnnBNBackward!(dg::CuArray{T}, g::CuArray{T}, db::CuArray{T},
Ptr{Nothing}, Ptr{T},
Ptr{Nothing}, Ptr{T}, Ptr{T}, Ptr{T},
Cdouble, Ptr{T}, Ptr{T}),
libcudnn_handle[], BATCHNORM_SPATIAL,
handle(), BATCHNORM_SPATIAL,
Ref(T(alpha)), Ref(T(beta)),
Ref(T(dalpha)), Ref(T(dbeta)),
xd, x,

View File

@ -1,5 +1,5 @@
using CuArrays.CUDNN: @check, libcudnn, cudnnStatus_t, cudnnTensorDescriptor_t,
cudnnBatchNormMode_t, cudnnHandle_t, libcudnn_handle, cudnnDataType, TensorDesc, FilterDesc
cudnnBatchNormMode_t, cudnnHandle_t, handle, cudnnDataType, TensorDesc, FilterDesc
using LinearAlgebra
@ -46,7 +46,7 @@ Base.unsafe_convert(::Type{Ptr{Nothing}}, d::RNNDesc) = d.ptr
function rnnParamSize(T, r, input)
size = Csize_t[0]
@check ccall((:cudnnGetRNNParamsSize, libcudnn), cudnnStatus_t, (Ptr{Nothing},Ptr{Nothing},Ptr{Nothing},Ptr{Csize_t},Cint),
libcudnn_handle[], r, TensorDesc(T, (1,input,1)), size, cudnnDataType(T))
handle(), r, TensorDesc(T, (1,input,1)), size, cudnnDataType(T))
return Int(size[])÷sizeof(T)
end
@ -62,7 +62,7 @@ function RNNDesc{T}(mode::Int, input::Int, hidden::Int; layers = 1) where T
direction = UNIDIRECTIONAL
algo = RNN_ALGO_STANDARD
@check ccall((:cudnnSetRNNDescriptor_v6,libcudnn), cudnnStatus_t, (Ptr{Nothing},Ptr{Nothing},Cint,Cint,Ptr{Nothing},Cint,Cint,Cint,Cint,Cint),
libcudnn_handle[],d[],hidden,layers,dropoutDesc,inputMode,direction,mode,algo,cudnnDataType(T))
handle(),d[],hidden,layers,dropoutDesc,inputMode,direction,mode,algo,cudnnDataType(T))
w = cuzeros(T, rnnParamSize(T, d[], input))
# TODO: avoid reserve allocation here
@ -76,7 +76,7 @@ end
function rnnWorkspaceSize(r::RNNDesc, seqlen, xdesc)
size = Csize_t[0]
@check ccall((:cudnnGetRNNWorkspaceSize, libcudnn), cudnnStatus_t, (Ptr{Nothing},Ptr{Nothing},Cint,Ptr{Ptr{Nothing}},Ptr{Csize_t}),
libcudnn_handle[], r, seqlen, xdesc, size)
handle(), r, seqlen, xdesc, size)
return Int(size[])
end
@ -93,7 +93,7 @@ getworkspace(r::RNNDesc, seqlen, xdesc) =
function rnnTrainingReserveSize(r::RNNDesc, seqlen, xdesc)
size = Csize_t[0]
@check ccall((:cudnnGetRNNTrainingReserveSize,libcudnn), cudnnStatus_t, (Ptr{Nothing}, Ptr{Nothing}, Cint, Ptr{Ptr{Nothing}}, Ptr{Csize_t}),
libcudnn_handle[], r, seqlen, xdesc, size)
handle(), r, seqlen, xdesc, size)
return Int(size[])
end
@ -106,7 +106,7 @@ function cudnnRNNForward(rnn::RNNDesc{T}, seqlen, xd, x, hd, h, cd, c, wd, w, yd
Ptr{Nothing}, Ptr{T}, Ptr{Ptr{Nothing}}, Ptr{T}, Ptr{Nothing}, Ptr{T},
Ptr{Nothing}, Ptr{T},
Ptr{Nothing}, Csize_t),
libcudnn_handle[], rnn, seqlen,
handle(), rnn, seqlen,
xd, x, hd, h, cd, c, wd, w, yd, y, hod, ho, cod, co,
workspace, length(workspace))
else
@ -114,7 +114,7 @@ function cudnnRNNForward(rnn::RNNDesc{T}, seqlen, xd, x, hd, h, cd, c, wd, w, yd
(Ptr{Nothing}, Ptr{Nothing}, Cint,
Ptr{Ptr{Nothing}}, Ptr{T}, Ptr{Nothing}, Ptr{T}, Ptr{Nothing}, Ptr{T}, Ptr{Nothing}, Ptr{T}, Ptr{Ptr{Nothing}}, Ptr{T}, Ptr{Nothing}, Ptr{T}, Ptr{Nothing}, Ptr{T},
Ptr{Nothing}, Csize_t, Ptr{Nothing}, Csize_t),
libcudnn_handle[], rnn, seqlen,
handle(), rnn, seqlen,
xd, x, hd, h, cd, c, wd, w, yd, y, hod, ho, cod, co,
workspace, length(workspace), reserve, length(reserve))
end
@ -174,7 +174,7 @@ function cudnnRNNBackwardData(rnn::RNNDesc{T}, seqlen, yd, y, dyd, dy, dhod, dho
Ptr{Nothing}, Ptr{T}, Ptr{Nothing}, Ptr{T}, Ptr{Nothing}, Ptr{T}, Ptr{Nothing},
Ptr{T}, Ptr{Ptr{Nothing}}, Ptr{T}, Ptr{Nothing}, Ptr{T}, Ptr{Nothing}, Ptr{T},
Ptr{Nothing}, Csize_t, Ptr{Nothing}, Csize_t),
libcudnn_handle[], rnn, seqlen, yd, y, dyd, dy, dhod, dho, dcod, dco,
handle(), rnn, seqlen, yd, y, dyd, dy, dhod, dho, dcod, dco,
wd, w, hd, h, cd, c, dxd, dx, dhd, dh, dcd, dc, ws, length(ws), rs, length(rs))
end
@ -206,7 +206,7 @@ function cudnnRNNBackwardWeights(rnn::RNNDesc{T}, seqlen, xd, x, hd, h, yd, y, d
Ptr{Nothing}, Csize_t, #ws
Ptr{Nothing}, Ptr{T}, #dw
Ptr{Nothing}, Csize_t), #rs
libcudnn_handle[], rnn, seqlen, xd, x, hd, h, yd, y,
handle(), rnn, seqlen, xd, x, hd, h, yd, y,
workspace, length(workspace), dwd, dw, reserve, length(reserve))
end