From 3fb83d642dc19e5fb5458f3a2344642aea236bc5 Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Fri, 26 Jan 2018 15:28:39 +0000 Subject: [PATCH] rnnWorkspaceSize --- src/cuda/cudnn.jl | 22 ++++++++++++++++------ 1 file changed, 16 insertions(+), 6 deletions(-) diff --git a/src/cuda/cudnn.jl b/src/cuda/cudnn.jl index 7929c6e2..7cd39e4b 100644 --- a/src/cuda/cudnn.jl +++ b/src/cuda/cudnn.jl @@ -38,15 +38,18 @@ const RNN_ALGO_PERSIST_STATIC = 1 const RNN_ALGO_PERSIST_DYNAMIC = 2 mutable struct RNNDesc + T::Type + input::Int + hidden::Int ptr::Ptr{Void} end Base.unsafe_convert(::Type{Ptr{Void}}, d::RNNDesc) = d.ptr -function RNNDesc(T, mode, input, hidden; layers = 1) +function RNNDesc(T::Type, mode::Int, input::Int, hidden::Int; layers = 1) d = [C_NULL] @check ccall((:cudnnCreateRNNDescriptor,libcudnn),cudnnStatus_t,(Ptr{Ptr{Void}},),d) - rd = RNNDesc(d[]) + rd = RNNDesc(T, input, hidden, d[]) finalizer(rd, x -> @check ccall((:cudnnDestroyRNNDescriptor,libcudnn),cudnnStatus_t,(Ptr{Void},),x)) @@ -55,13 +58,20 @@ function RNNDesc(T, mode, input, hidden; layers = 1) direction = UNIDIRECTIONAL algo = RNN_ALGO_STANDARD @check ccall((:cudnnSetRNNDescriptor_v6,libcudnn), cudnnStatus_t, (Ptr{Void},Ptr{Void},Cint,Cint,Ptr{Void},Cint,Cint,Cint,Cint,Cint), - libcudnn_handle[],rd,hidden,layers,dropoutDesc,inputMode,direction,mode,algo,cudnnDataType(T)) + libcudnn_handle[],rd,hidden,layers,dropoutDesc,inputMode,direction,mode,algo,cudnnDataType(rd.T)) return rd end -function rnnParamSize(r::RNNDesc, x) +function rnnWorkspaceSize(r::RNNDesc) size = Csize_t[0] - @check ccall((:cudnnGetRNNParamsSize, libcudnn), cudnnStatus_t, (Ptr{Void},Ptr{Void},Ptr{Void},Ptr{Csize_t},Cint), - libcudnn_handle[], r, TensorDesc(x), size, cudnnDataType(eltype(x))) + @check ccall((:cudnnGetRNNWorkspaceSize, libcudnn), cudnnStatus_t, (Ptr{Void},Ptr{Void},Cint,Ptr{Ptr{Void}},Ptr{Csize_t}), + libcudnn_handle[], r, 1, [TensorDesc(r.T, (1,r.input,1))], size) + return Int(size[]) +end + +function rnnParamSize(r::RNNDesc) + size = Csize_t[0] + @check ccall((:cudnnGetRNNParamsSize, libcudnn), cudnnStatus_t, (Ptr{Void},Ptr{Void},Ptr{Void},Ptr{Csize_t},Cint), + libcudnn_handle[], r, TensorDesc(r.T, (1,r.input,1)), size, cudnnDataType(r.T)) return Int(size[]) end