include cudnn.jl

This commit is contained in:
Mike J Innes 2018-11-06 12:39:54 +00:00
parent 4ba891f666
commit 0c19dad700
2 changed files with 10 additions and 5 deletions

View File

@ -2,10 +2,10 @@ module CUDA
using ..CuArrays
if isdefined(CuArrays, :libcudnn_handle)
handle() = CuArrays.libcudnn_handle[]
if CuArrays.libcudnn != nothing
include("cudnn.jl")
else
handle() = CuArrays.CUDNN.handle()
@warn("CUDNN is not installed, some functionality will not be available.")
end
end

View File

@ -1,8 +1,13 @@
using .CuArrays.CUDNN: @check, libcudnn, cudnnStatus_t,
cudnnDataType, TensorDesc, FilterDesc
using LinearAlgebra
if isdefined(CuArrays, :libcudnn_handle)
handle() = CuArrays.libcudnn_handle[]
else
handle() = CuArrays.CUDNN.handle()
end
mutable struct DropoutDesc
ptr::Ptr{Nothing}
states::CuVector{UInt8}