From af0c5523ff4cf63d97de0c1dc12fabb0b6b89b10 Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Fri, 26 Jan 2018 15:35:14 +0000 Subject: [PATCH] rnnTrainingReserveSize --- src/cuda/cudnn.jl | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/src/cuda/cudnn.jl b/src/cuda/cudnn.jl index 7cd39e4b..e60fcb5c 100644 --- a/src/cuda/cudnn.jl +++ b/src/cuda/cudnn.jl @@ -69,6 +69,13 @@ function rnnWorkspaceSize(r::RNNDesc) return Int(size[]) end +function rnnTrainingReserveSize(r::RNNDesc) + size = Csize_t[0] + @check ccall((:cudnnGetRNNTrainingReserveSize,libcudnn), cudnnStatus_t, (Ptr{Void}, Ptr{Void}, Cint, Ptr{Ptr{Void}}, Ptr{Csize_t}), + libcudnn_handle[], r, 1, [TensorDesc(r.T, (1,r.input,1))], size) + return Int(size[]) +end + function rnnParamSize(r::RNNDesc) size = Csize_t[0] @check ccall((:cudnnGetRNNParamsSize, libcudnn), cudnnStatus_t, (Ptr{Void},Ptr{Void},Ptr{Void},Ptr{Csize_t},Cint),