From 0c19dad700b16f38dbdf6382cb1a5afd1e9e6f11 Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Tue, 6 Nov 2018 12:39:54 +0000 Subject: [PATCH] include cudnn.jl --- src/cuda/cuda.jl | 6 +++--- src/cuda/cudnn.jl | 9 +++++++-- 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/src/cuda/cuda.jl b/src/cuda/cuda.jl index eb28abcf..dc5ca272 100644 --- a/src/cuda/cuda.jl +++ b/src/cuda/cuda.jl @@ -2,10 +2,10 @@ module CUDA using ..CuArrays -if isdefined(CuArrays, :libcudnn_handle) - handle() = CuArrays.libcudnn_handle[] +if CuArrays.libcudnn != nothing + include("cudnn.jl") else - handle() = CuArrays.CUDNN.handle() + @warn("CUDNN is not installed, some functionality will not be available.") end end diff --git a/src/cuda/cudnn.jl b/src/cuda/cudnn.jl index f1c64226..3bddfbe2 100644 --- a/src/cuda/cudnn.jl +++ b/src/cuda/cudnn.jl @@ -1,8 +1,13 @@ using .CuArrays.CUDNN: @check, libcudnn, cudnnStatus_t, cudnnDataType, TensorDesc, FilterDesc - using LinearAlgebra +if isdefined(CuArrays, :libcudnn_handle) + handle() = CuArrays.libcudnn_handle[] +else + handle() = CuArrays.CUDNN.handle() +end + mutable struct DropoutDesc ptr::Ptr{Nothing} states::CuVector{UInt8} @@ -91,7 +96,7 @@ function RNNDesc{T}(mode::Int, input::Int, hidden::Int; layers = 1) where T rd = RNNDesc{T}(mode, input, hidden, w, params(w, input, hidden, ngates(mode))..., d[]) finalizer(rd) do x @check ccall((:cudnnDestroyRNNDescriptor,libcudnn),cudnnStatus_t,(Ptr{Nothing},),x) - end + end return rd end