rnnParamSize
This commit is contained in:
parent
ee6c3e18a9
commit
6b4e114d5d
@ -1,4 +1,5 @@
|
||||
using CuArrays.CUDNN: @check, libcudnn, cudnnStatus_t, libcudnn_handle, cudnnDataType
|
||||
using CuArrays.CUDNN: @check, libcudnn, cudnnStatus_t, libcudnn_handle,
|
||||
cudnnDataType, TensorDesc
|
||||
|
||||
mutable struct DropoutDesc
|
||||
ptr::Ptr{Void}
|
||||
@ -49,7 +50,7 @@ function RNNDesc(T, mode, input, hidden; layers = 1)
|
||||
finalizer(rd, x ->
|
||||
@check ccall((:cudnnDestroyRNNDescriptor,libcudnn),cudnnStatus_t,(Ptr{Void},),x))
|
||||
|
||||
dropoutDesc = DropoutDesc()
|
||||
dropoutDesc = DropoutDesc(0)
|
||||
inputMode = LINEAR_INPUT
|
||||
direction = UNIDIRECTIONAL
|
||||
algo = RNN_ALGO_STANDARD
|
||||
@ -57,3 +58,10 @@ function RNNDesc(T, mode, input, hidden; layers = 1)
|
||||
libcudnn_handle[],rd,hidden,layers,dropoutDesc,inputMode,direction,mode,algo,cudnnDataType(T))
|
||||
return rd
|
||||
end
|
||||
|
||||
function rnnParamSize(r::RNNDesc, x)
|
||||
size = Csize_t[0]
|
||||
@check ccall((:cudnnGetRNNParamsSize, libcudnn), cudnnStatus_t, (Ptr{Void},Ptr{Void},Ptr{Void},Ptr{Csize_t},Cint),
|
||||
libcudnn_handle[], r, TensorDesc(x), size, cudnnDataType(eltype(x)))
|
||||
return Int(size[])
|
||||
end
|
||||
|
Loading…
Reference in New Issue
Block a user