include cudnn.jl

This commit is contained in:
Mike J Innes 2018-11-06 12:39:54 +00:00
parent 4ba891f666
commit 0c19dad700
2 changed files with 10 additions and 5 deletions

View File

@ -2,10 +2,10 @@ module CUDA
using ..CuArrays using ..CuArrays
if isdefined(CuArrays, :libcudnn_handle) if CuArrays.libcudnn != nothing
handle() = CuArrays.libcudnn_handle[] include("cudnn.jl")
else else
handle() = CuArrays.CUDNN.handle() @warn("CUDNN is not installed, some functionality will not be available.")
end end
end end

View File

@ -1,8 +1,13 @@
using .CuArrays.CUDNN: @check, libcudnn, cudnnStatus_t, using .CuArrays.CUDNN: @check, libcudnn, cudnnStatus_t,
cudnnDataType, TensorDesc, FilterDesc cudnnDataType, TensorDesc, FilterDesc
using LinearAlgebra using LinearAlgebra
if isdefined(CuArrays, :libcudnn_handle)
handle() = CuArrays.libcudnn_handle[]
else
handle() = CuArrays.CUDNN.handle()
end
mutable struct DropoutDesc mutable struct DropoutDesc
ptr::Ptr{Nothing} ptr::Ptr{Nothing}
states::CuVector{UInt8} states::CuVector{UInt8}