basic RNNDesc

This commit is contained in:
Mike J Innes 2018-01-24 18:45:24 +00:00
parent 842bf03051
commit ee6c3e18a9
3 changed files with 68 additions and 0 deletions

View File

@ -34,4 +34,6 @@ include("layers/normalisation.jl")
include("data/Data.jl")
@require CuArrays include("cuda/cuda.jl")
end # module

7
src/cuda/cuda.jl Normal file
View File

@ -0,0 +1,7 @@
module CUDA
using CuArrays
CuArrays.cudnn_available() && include("cudnn.jl")
end

59
src/cuda/cudnn.jl Normal file
View File

@ -0,0 +1,59 @@
using CuArrays.CUDNN: @check, libcudnn, cudnnStatus_t, libcudnn_handle, cudnnDataType
mutable struct DropoutDesc
ptr::Ptr{Void}
states::CuVector{UInt8}
end
Base.unsafe_convert(::Type{Ptr{Void}}, dd::DropoutDesc) = dd.ptr
function DropoutDesc(ρ::Real; seed::Integer=0)
d = [C_NULL]
s = Csize_t[0]
@check ccall((:cudnnCreateDropoutDescriptor,libcudnn), cudnnStatus_t, (Ptr{Ptr{Void}},), d)
@check ccall((:cudnnDropoutGetStatesSize,libcudnn),cudnnStatus_t,(Ptr{Void},Ptr{Csize_t}),libcudnn_handle[],s)
states = CuArray{UInt8}(s[]) # TODO: can we drop this when ρ=0?
desc = DropoutDesc(d[], states)
@check ccall((:cudnnSetDropoutDescriptor,libcudnn),cudnnStatus_t,(Ptr{Void},Ptr{Void},Cfloat,Ptr{Void},Csize_t,Culonglong),
desc,libcudnn_handle[],ρ,states,length(states),seed)
finalizer(desc, x ->
@check ccall((:cudnnDestroyDropoutDescriptor,libcudnn),cudnnStatus_t,(Ptr{Void},),x))
return desc
end
const RNN_RELU = 0 # Stock RNN with ReLu activation
const RNN_TANH = 1 # Stock RNN with tanh activation
const LSTM = 2 # LSTM with no peephole connections
const GRU = 3 # Using h' = tanh(r * Uh(t-1) + Wx) and h = (1 - z) * h' + z * h(t-1)
const LINEAR_INPUT = 0
const SKIP_INPUT = 1
const UNIDIRECTIONAL = 0
const BIDIRECTIONAL = 1
const RNN_ALGO_STANDARD = 0
const RNN_ALGO_PERSIST_STATIC = 1
const RNN_ALGO_PERSIST_DYNAMIC = 2
mutable struct RNNDesc
ptr::Ptr{Void}
end
Base.unsafe_convert(::Type{Ptr{Void}}, d::RNNDesc) = d.ptr
function RNNDesc(T, mode, input, hidden; layers = 1)
d = [C_NULL]
@check ccall((:cudnnCreateRNNDescriptor,libcudnn),cudnnStatus_t,(Ptr{Ptr{Void}},),d)
rd = RNNDesc(d[])
finalizer(rd, x ->
@check ccall((:cudnnDestroyRNNDescriptor,libcudnn),cudnnStatus_t,(Ptr{Void},),x))
dropoutDesc = DropoutDesc()
inputMode = LINEAR_INPUT
direction = UNIDIRECTIONAL
algo = RNN_ALGO_STANDARD
@check ccall((:cudnnSetRNNDescriptor_v6,libcudnn), cudnnStatus_t, (Ptr{Void},Ptr{Void},Cint,Cint,Ptr{Void},Cint,Cint,Cint,Cint,Cint),
libcudnn_handle[],rd,hidden,layers,dropoutDesc,inputMode,direction,mode,algo,cudnnDataType(T))
return rd
end