rnnParamSize

This commit is contained in:
Mike J Innes 2018-01-26 12:16:34 +00:00
parent ee6c3e18a9
commit 6b4e114d5d

View File

@ -1,4 +1,5 @@
using CuArrays.CUDNN: @check, libcudnn, cudnnStatus_t, libcudnn_handle, cudnnDataType
using CuArrays.CUDNN: @check, libcudnn, cudnnStatus_t, libcudnn_handle,
cudnnDataType, TensorDesc
mutable struct DropoutDesc
ptr::Ptr{Void}
@ -49,7 +50,7 @@ function RNNDesc(T, mode, input, hidden; layers = 1)
finalizer(rd, x ->
@check ccall((:cudnnDestroyRNNDescriptor,libcudnn),cudnnStatus_t,(Ptr{Void},),x))
dropoutDesc = DropoutDesc()
dropoutDesc = DropoutDesc(0)
inputMode = LINEAR_INPUT
direction = UNIDIRECTIONAL
algo = RNN_ALGO_STANDARD
@ -57,3 +58,10 @@ function RNNDesc(T, mode, input, hidden; layers = 1)
libcudnn_handle[],rd,hidden,layers,dropoutDesc,inputMode,direction,mode,algo,cudnnDataType(T))
return rd
end
function rnnParamSize(r::RNNDesc, x)
size = Csize_t[0]
@check ccall((:cudnnGetRNNParamsSize, libcudnn), cudnnStatus_t, (Ptr{Void},Ptr{Void},Ptr{Void},Ptr{Csize_t},Cint),
libcudnn_handle[], r, TensorDesc(x), size, cudnnDataType(eltype(x)))
return Int(size[])
end