include cudnn.jl
This commit is contained in:
parent
4ba891f666
commit
0c19dad700
@ -2,10 +2,10 @@ module CUDA
|
|||||||
|
|
||||||
using ..CuArrays
|
using ..CuArrays
|
||||||
|
|
||||||
if isdefined(CuArrays, :libcudnn_handle)
|
if CuArrays.libcudnn != nothing
|
||||||
handle() = CuArrays.libcudnn_handle[]
|
include("cudnn.jl")
|
||||||
else
|
else
|
||||||
handle() = CuArrays.CUDNN.handle()
|
@warn("CUDNN is not installed, some functionality will not be available.")
|
||||||
end
|
end
|
||||||
|
|
||||||
end
|
end
|
||||||
|
@ -1,8 +1,13 @@
|
|||||||
using .CuArrays.CUDNN: @check, libcudnn, cudnnStatus_t,
|
using .CuArrays.CUDNN: @check, libcudnn, cudnnStatus_t,
|
||||||
cudnnDataType, TensorDesc, FilterDesc
|
cudnnDataType, TensorDesc, FilterDesc
|
||||||
|
|
||||||
using LinearAlgebra
|
using LinearAlgebra
|
||||||
|
|
||||||
|
if isdefined(CuArrays, :libcudnn_handle)
|
||||||
|
handle() = CuArrays.libcudnn_handle[]
|
||||||
|
else
|
||||||
|
handle() = CuArrays.CUDNN.handle()
|
||||||
|
end
|
||||||
|
|
||||||
mutable struct DropoutDesc
|
mutable struct DropoutDesc
|
||||||
ptr::Ptr{Nothing}
|
ptr::Ptr{Nothing}
|
||||||
states::CuVector{UInt8}
|
states::CuVector{UInt8}
|
||||||
|
Loading…
Reference in New Issue
Block a user