include cudnn.jl

This commit is contained in:
Mike J Innes 2018-11-06 12:39:54 +00:00
parent 4ba891f666
commit 0c19dad700
2 changed files with 10 additions and 5 deletions

View File

@ -2,10 +2,10 @@ module CUDA
using ..CuArrays
if isdefined(CuArrays, :libcudnn_handle)
handle() = CuArrays.libcudnn_handle[]
if CuArrays.libcudnn != nothing
include("cudnn.jl")
else
handle() = CuArrays.CUDNN.handle()
@warn("CUDNN is not installed, some functionality will not be available.")
end
end

View File

@ -1,8 +1,13 @@
using .CuArrays.CUDNN: @check, libcudnn, cudnnStatus_t,
cudnnDataType, TensorDesc, FilterDesc
using LinearAlgebra
if isdefined(CuArrays, :libcudnn_handle)
handle() = CuArrays.libcudnn_handle[]
else
handle() = CuArrays.CUDNN.handle()
end
mutable struct DropoutDesc
ptr::Ptr{Nothing}
states::CuVector{UInt8}
@ -91,7 +96,7 @@ function RNNDesc{T}(mode::Int, input::Int, hidden::Int; layers = 1) where T
rd = RNNDesc{T}(mode, input, hidden, w, params(w, input, hidden, ngates(mode))..., d[])
finalizer(rd) do x
@check ccall((:cudnnDestroyRNNDescriptor,libcudnn),cudnnStatus_t,(Ptr{Nothing},),x)
end
end
return rd
end