diff --git a/src/cuda/cuda.jl b/src/cuda/cuda.jl index 15126aca..eb28abcf 100644 --- a/src/cuda/cuda.jl +++ b/src/cuda/cuda.jl @@ -2,6 +2,10 @@ module CUDA using ..CuArrays -CuArrays.libcudnn != nothing && include("cudnn.jl") +if isdefined(CuArrays, :libcudnn_handle) + handle() = CuArrays.libcudnn_handle[] +else + handle() = CuArrays.CUDNN.handle() +end end