From ee6c3e18a97d0feacd911135ad2df4199316fe5d Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Wed, 24 Jan 2018 18:45:24 +0000 Subject: [PATCH] basic RNNDesc --- src/Flux.jl | 2 ++ src/cuda/cuda.jl | 7 ++++++ src/cuda/cudnn.jl | 59 +++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 68 insertions(+) create mode 100644 src/cuda/cuda.jl create mode 100644 src/cuda/cudnn.jl diff --git a/src/Flux.jl b/src/Flux.jl index bc7a3a48..179d09c2 100644 --- a/src/Flux.jl +++ b/src/Flux.jl @@ -34,4 +34,6 @@ include("layers/normalisation.jl") include("data/Data.jl") +@require CuArrays include("cuda/cuda.jl") + end # module diff --git a/src/cuda/cuda.jl b/src/cuda/cuda.jl new file mode 100644 index 00000000..eaa3fe00 --- /dev/null +++ b/src/cuda/cuda.jl @@ -0,0 +1,7 @@ +module CUDA + +using CuArrays + +CuArrays.cudnn_available() && include("cudnn.jl") + +end diff --git a/src/cuda/cudnn.jl b/src/cuda/cudnn.jl new file mode 100644 index 00000000..72246a85 --- /dev/null +++ b/src/cuda/cudnn.jl @@ -0,0 +1,59 @@ +using CuArrays.CUDNN: @check, libcudnn, cudnnStatus_t, libcudnn_handle, cudnnDataType + +mutable struct DropoutDesc + ptr::Ptr{Void} + states::CuVector{UInt8} +end + +Base.unsafe_convert(::Type{Ptr{Void}}, dd::DropoutDesc) = dd.ptr + +function DropoutDesc(ρ::Real; seed::Integer=0) + d = [C_NULL] + s = Csize_t[0] + @check ccall((:cudnnCreateDropoutDescriptor,libcudnn), cudnnStatus_t, (Ptr{Ptr{Void}},), d) + @check ccall((:cudnnDropoutGetStatesSize,libcudnn),cudnnStatus_t,(Ptr{Void},Ptr{Csize_t}),libcudnn_handle[],s) + states = CuArray{UInt8}(s[]) # TODO: can we drop this when ρ=0? + desc = DropoutDesc(d[], states) + @check ccall((:cudnnSetDropoutDescriptor,libcudnn),cudnnStatus_t,(Ptr{Void},Ptr{Void},Cfloat,Ptr{Void},Csize_t,Culonglong), + desc,libcudnn_handle[],ρ,states,length(states),seed) + finalizer(desc, x -> + @check ccall((:cudnnDestroyDropoutDescriptor,libcudnn),cudnnStatus_t,(Ptr{Void},),x)) + return desc +end + +const RNN_RELU = 0 # Stock RNN with ReLu activation +const RNN_TANH = 1 # Stock RNN with tanh activation +const LSTM = 2 # LSTM with no peephole connections +const GRU = 3 # Using h' = tanh(r * Uh(t-1) + Wx) and h = (1 - z) * h' + z * h(t-1) + +const LINEAR_INPUT = 0 +const SKIP_INPUT = 1 + +const UNIDIRECTIONAL = 0 +const BIDIRECTIONAL = 1 + +const RNN_ALGO_STANDARD = 0 +const RNN_ALGO_PERSIST_STATIC = 1 +const RNN_ALGO_PERSIST_DYNAMIC = 2 + +mutable struct RNNDesc + ptr::Ptr{Void} +end + +Base.unsafe_convert(::Type{Ptr{Void}}, d::RNNDesc) = d.ptr + +function RNNDesc(T, mode, input, hidden; layers = 1) + d = [C_NULL] + @check ccall((:cudnnCreateRNNDescriptor,libcudnn),cudnnStatus_t,(Ptr{Ptr{Void}},),d) + rd = RNNDesc(d[]) + finalizer(rd, x -> + @check ccall((:cudnnDestroyRNNDescriptor,libcudnn),cudnnStatus_t,(Ptr{Void},),x)) + + dropoutDesc = DropoutDesc() + inputMode = LINEAR_INPUT + direction = UNIDIRECTIONAL + algo = RNN_ALGO_STANDARD + @check ccall((:cudnnSetRNNDescriptor_v6,libcudnn), cudnnStatus_t, (Ptr{Void},Ptr{Void},Cint,Cint,Ptr{Void},Cint,Cint,Cint,Cint,Cint), + libcudnn_handle[],rd,hidden,layers,dropoutDesc,inputMode,direction,mode,algo,cudnnDataType(T)) + return rd +end