diff --git a/src/cuda/cudnn.jl b/src/cuda/cudnn.jl
index 72246a85..7929c6e2 100644
--- a/src/cuda/cudnn.jl
+++ b/src/cuda/cudnn.jl
@@ -1,4 +1,5 @@
-using CuArrays.CUDNN: @check, libcudnn, cudnnStatus_t, libcudnn_handle, cudnnDataType
+using CuArrays.CUDNN: @check, libcudnn, cudnnStatus_t, libcudnn_handle,
+  cudnnDataType, TensorDesc
 
 mutable struct DropoutDesc
   ptr::Ptr{Void}
@@ -49,7 +50,7 @@ function RNNDesc(T, mode, input, hidden; layers = 1)
   finalizer(rd, x ->
     @check ccall((:cudnnDestroyRNNDescriptor,libcudnn),cudnnStatus_t,(Ptr{Void},),x))
 
-  dropoutDesc = DropoutDesc()
+  dropoutDesc = DropoutDesc(0)
   inputMode = LINEAR_INPUT
   direction = UNIDIRECTIONAL
   algo = RNN_ALGO_STANDARD
@@ -57,3 +58,10 @@ function RNNDesc(T, mode, input, hidden; layers = 1)
     libcudnn_handle[],rd,hidden,layers,dropoutDesc,inputMode,direction,mode,algo,cudnnDataType(T))
   return rd
 end
+
+function rnnParamSize(r::RNNDesc, x)
+  size = Csize_t[0]
+  @check ccall((:cudnnGetRNNParamsSize, libcudnn), cudnnStatus_t, (Ptr{Void},Ptr{Void},Ptr{Void},Ptr{Csize_t},Cint),
+    libcudnn_handle[], r, TensorDesc(x), size, cudnnDataType(eltype(x)))
+  return Int(size[])
+end
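
A minimal usage sketch (not part of the diff above) for the new rnnParamSize helper; the mode value 2 (cuDNN's CUDNN_LSTM enum) and the concrete input/hidden sizes are illustrative assumptions:

# Query the size of the packed cuDNN weight/bias buffer for an LSTM descriptor.
using CuArrays

T = Float32
input, hidden = 10, 5
rd = RNNDesc(T, 2, input, hidden)   # mode 2 == CUDNN_LSTM (assumed enum value)
x  = CuArray(rand(T, input))        # one input column, used only for its TensorDesc
nbytes  = rnnParamSize(rd, x)       # total parameter blob size in bytes
nparams = nbytes ÷ sizeof(T)        # number of T-typed parameters cuDNN expects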