rnnWorkspaceSize
This commit is contained in:
parent
6b4e114d5d
commit
3fb83d642d
@ -38,15 +38,18 @@ const RNN_ALGO_PERSIST_STATIC = 1
|
|||||||
const RNN_ALGO_PERSIST_DYNAMIC = 2
|
const RNN_ALGO_PERSIST_DYNAMIC = 2
|
||||||
|
|
||||||
mutable struct RNNDesc
|
mutable struct RNNDesc
|
||||||
|
T::Type
|
||||||
|
input::Int
|
||||||
|
hidden::Int
|
||||||
ptr::Ptr{Void}
|
ptr::Ptr{Void}
|
||||||
end
|
end
|
||||||
|
|
||||||
Base.unsafe_convert(::Type{Ptr{Void}}, d::RNNDesc) = d.ptr
|
Base.unsafe_convert(::Type{Ptr{Void}}, d::RNNDesc) = d.ptr
|
||||||
|
|
||||||
function RNNDesc(T, mode, input, hidden; layers = 1)
|
function RNNDesc(T::Type, mode::Int, input::Int, hidden::Int; layers = 1)
|
||||||
d = [C_NULL]
|
d = [C_NULL]
|
||||||
@check ccall((:cudnnCreateRNNDescriptor,libcudnn),cudnnStatus_t,(Ptr{Ptr{Void}},),d)
|
@check ccall((:cudnnCreateRNNDescriptor,libcudnn),cudnnStatus_t,(Ptr{Ptr{Void}},),d)
|
||||||
rd = RNNDesc(d[])
|
rd = RNNDesc(T, input, hidden, d[])
|
||||||
finalizer(rd, x ->
|
finalizer(rd, x ->
|
||||||
@check ccall((:cudnnDestroyRNNDescriptor,libcudnn),cudnnStatus_t,(Ptr{Void},),x))
|
@check ccall((:cudnnDestroyRNNDescriptor,libcudnn),cudnnStatus_t,(Ptr{Void},),x))
|
||||||
|
|
||||||
@ -55,13 +58,20 @@ function RNNDesc(T, mode, input, hidden; layers = 1)
|
|||||||
direction = UNIDIRECTIONAL
|
direction = UNIDIRECTIONAL
|
||||||
algo = RNN_ALGO_STANDARD
|
algo = RNN_ALGO_STANDARD
|
||||||
@check ccall((:cudnnSetRNNDescriptor_v6,libcudnn), cudnnStatus_t, (Ptr{Void},Ptr{Void},Cint,Cint,Ptr{Void},Cint,Cint,Cint,Cint,Cint),
|
@check ccall((:cudnnSetRNNDescriptor_v6,libcudnn), cudnnStatus_t, (Ptr{Void},Ptr{Void},Cint,Cint,Ptr{Void},Cint,Cint,Cint,Cint,Cint),
|
||||||
libcudnn_handle[],rd,hidden,layers,dropoutDesc,inputMode,direction,mode,algo,cudnnDataType(T))
|
libcudnn_handle[],rd,hidden,layers,dropoutDesc,inputMode,direction,mode,algo,cudnnDataType(rd.T))
|
||||||
return rd
|
return rd
|
||||||
end
|
end
|
||||||
|
|
||||||
function rnnParamSize(r::RNNDesc, x)
|
function rnnWorkspaceSize(r::RNNDesc)
|
||||||
size = Csize_t[0]
|
size = Csize_t[0]
|
||||||
@check ccall((:cudnnGetRNNParamsSize, libcudnn), cudnnStatus_t, (Ptr{Void},Ptr{Void},Ptr{Void},Ptr{Csize_t},Cint),
|
@check ccall((:cudnnGetRNNWorkspaceSize, libcudnn), cudnnStatus_t, (Ptr{Void},Ptr{Void},Cint,Ptr{Ptr{Void}},Ptr{Csize_t}),
|
||||||
libcudnn_handle[], r, TensorDesc(x), size, cudnnDataType(eltype(x)))
|
libcudnn_handle[], r, 1, [TensorDesc(r.T, (1,r.input,1))], size)
|
||||||
|
return Int(size[])
|
||||||
|
end
|
||||||
|
|
||||||
|
function rnnParamSize(r::RNNDesc)
|
||||||
|
size = Csize_t[0]
|
||||||
|
@check ccall((:cudnnGetRNNParamsSize, libcudnn), cudnnStatus_t, (Ptr{Void},Ptr{Void},Ptr{Void},Ptr{Csize_t},Cint),
|
||||||
|
libcudnn_handle[], r, TensorDesc(r.T, (1,r.input,1)), size, cudnnDataType(r.T))
|
||||||
return Int(size[])
|
return Int(size[])
|
||||||
end
|
end
|
||||||
|
Loading…
Reference in New Issue
Block a user