basic RNNDesc
This commit is contained in:
parent
842bf03051
commit
ee6c3e18a9
@ -34,4 +34,6 @@ include("layers/normalisation.jl")
|
||||
|
||||
include("data/Data.jl")
|
||||
|
||||
@require CuArrays include("cuda/cuda.jl")
|
||||
|
||||
end # module
|
||||
|
7
src/cuda/cuda.jl
Normal file
7
src/cuda/cuda.jl
Normal file
@ -0,0 +1,7 @@
|
||||
module CUDA
|
||||
|
||||
using CuArrays
|
||||
|
||||
CuArrays.cudnn_available() && include("cudnn.jl")
|
||||
|
||||
end
|
59
src/cuda/cudnn.jl
Normal file
59
src/cuda/cudnn.jl
Normal file
@ -0,0 +1,59 @@
|
||||
using CuArrays.CUDNN: @check, libcudnn, cudnnStatus_t, libcudnn_handle, cudnnDataType
|
||||
|
||||
mutable struct DropoutDesc
|
||||
ptr::Ptr{Void}
|
||||
states::CuVector{UInt8}
|
||||
end
|
||||
|
||||
Base.unsafe_convert(::Type{Ptr{Void}}, dd::DropoutDesc) = dd.ptr
|
||||
|
||||
function DropoutDesc(ρ::Real; seed::Integer=0)
|
||||
d = [C_NULL]
|
||||
s = Csize_t[0]
|
||||
@check ccall((:cudnnCreateDropoutDescriptor,libcudnn), cudnnStatus_t, (Ptr{Ptr{Void}},), d)
|
||||
@check ccall((:cudnnDropoutGetStatesSize,libcudnn),cudnnStatus_t,(Ptr{Void},Ptr{Csize_t}),libcudnn_handle[],s)
|
||||
states = CuArray{UInt8}(s[]) # TODO: can we drop this when ρ=0?
|
||||
desc = DropoutDesc(d[], states)
|
||||
@check ccall((:cudnnSetDropoutDescriptor,libcudnn),cudnnStatus_t,(Ptr{Void},Ptr{Void},Cfloat,Ptr{Void},Csize_t,Culonglong),
|
||||
desc,libcudnn_handle[],ρ,states,length(states),seed)
|
||||
finalizer(desc, x ->
|
||||
@check ccall((:cudnnDestroyDropoutDescriptor,libcudnn),cudnnStatus_t,(Ptr{Void},),x))
|
||||
return desc
|
||||
end
|
||||
|
||||
const RNN_RELU = 0 # Stock RNN with ReLu activation
|
||||
const RNN_TANH = 1 # Stock RNN with tanh activation
|
||||
const LSTM = 2 # LSTM with no peephole connections
|
||||
const GRU = 3 # Using h' = tanh(r * Uh(t-1) + Wx) and h = (1 - z) * h' + z * h(t-1)
|
||||
|
||||
const LINEAR_INPUT = 0
|
||||
const SKIP_INPUT = 1
|
||||
|
||||
const UNIDIRECTIONAL = 0
|
||||
const BIDIRECTIONAL = 1
|
||||
|
||||
const RNN_ALGO_STANDARD = 0
|
||||
const RNN_ALGO_PERSIST_STATIC = 1
|
||||
const RNN_ALGO_PERSIST_DYNAMIC = 2
|
||||
|
||||
mutable struct RNNDesc
|
||||
ptr::Ptr{Void}
|
||||
end
|
||||
|
||||
Base.unsafe_convert(::Type{Ptr{Void}}, d::RNNDesc) = d.ptr
|
||||
|
||||
function RNNDesc(T, mode, input, hidden; layers = 1)
|
||||
d = [C_NULL]
|
||||
@check ccall((:cudnnCreateRNNDescriptor,libcudnn),cudnnStatus_t,(Ptr{Ptr{Void}},),d)
|
||||
rd = RNNDesc(d[])
|
||||
finalizer(rd, x ->
|
||||
@check ccall((:cudnnDestroyRNNDescriptor,libcudnn),cudnnStatus_t,(Ptr{Void},),x))
|
||||
|
||||
dropoutDesc = DropoutDesc()
|
||||
inputMode = LINEAR_INPUT
|
||||
direction = UNIDIRECTIONAL
|
||||
algo = RNN_ALGO_STANDARD
|
||||
@check ccall((:cudnnSetRNNDescriptor_v6,libcudnn), cudnnStatus_t, (Ptr{Void},Ptr{Void},Cint,Cint,Ptr{Void},Cint,Cint,Cint,Cint,Cint),
|
||||
libcudnn_handle[],rd,hidden,layers,dropoutDesc,inputMode,direction,mode,algo,cudnnDataType(T))
|
||||
return rd
|
||||
end
|
Loading…
Reference in New Issue
Block a user