From 6b4e114d5dc6665747a17469eb6f26ccc6b1c2ce Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Fri, 26 Jan 2018 12:16:34 +0000 Subject: [PATCH] rnnParamSize --- src/cuda/cudnn.jl | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/src/cuda/cudnn.jl b/src/cuda/cudnn.jl index 72246a85..7929c6e2 100644 --- a/src/cuda/cudnn.jl +++ b/src/cuda/cudnn.jl @@ -1,4 +1,5 @@ -using CuArrays.CUDNN: @check, libcudnn, cudnnStatus_t, libcudnn_handle, cudnnDataType +using CuArrays.CUDNN: @check, libcudnn, cudnnStatus_t, libcudnn_handle, + cudnnDataType, TensorDesc mutable struct DropoutDesc ptr::Ptr{Void} @@ -49,7 +50,7 @@ function RNNDesc(T, mode, input, hidden; layers = 1) finalizer(rd, x -> @check ccall((:cudnnDestroyRNNDescriptor,libcudnn),cudnnStatus_t,(Ptr{Void},),x)) - dropoutDesc = DropoutDesc() + dropoutDesc = DropoutDesc(0) inputMode = LINEAR_INPUT direction = UNIDIRECTIONAL algo = RNN_ALGO_STANDARD @@ -57,3 +58,10 @@ function RNNDesc(T, mode, input, hidden; layers = 1) libcudnn_handle[],rd,hidden,layers,dropoutDesc,inputMode,direction,mode,algo,cudnnDataType(T)) return rd end + +function rnnParamSize(r::RNNDesc, x) + size = Csize_t[0] + @check ccall((:cudnnGetRNNParamsSize, libcudnn), cudnnStatus_t, (Ptr{Void},Ptr{Void},Ptr{Void},Ptr{Csize_t},Cint), + libcudnn_handle[], r, TensorDesc(x), size, cudnnDataType(eltype(x))) + return Int(size[]) +end