include cudnn.jl
This commit is contained in:
parent
4ba891f666
commit
0c19dad700
@ -2,10 +2,10 @@ module CUDA
|
||||
|
||||
using ..CuArrays
|
||||
|
||||
if isdefined(CuArrays, :libcudnn_handle)
|
||||
handle() = CuArrays.libcudnn_handle[]
|
||||
if CuArrays.libcudnn != nothing
|
||||
include("cudnn.jl")
|
||||
else
|
||||
handle() = CuArrays.CUDNN.handle()
|
||||
@warn("CUDNN is not installed, some functionality will not be available.")
|
||||
end
|
||||
|
||||
end
|
||||
|
@ -1,8 +1,13 @@
|
||||
using .CuArrays.CUDNN: @check, libcudnn, cudnnStatus_t,
|
||||
cudnnDataType, TensorDesc, FilterDesc
|
||||
|
||||
using LinearAlgebra
|
||||
|
||||
if isdefined(CuArrays, :libcudnn_handle)
|
||||
handle() = CuArrays.libcudnn_handle[]
|
||||
else
|
||||
handle() = CuArrays.CUDNN.handle()
|
||||
end
|
||||
|
||||
mutable struct DropoutDesc
|
||||
ptr::Ptr{Nothing}
|
||||
states::CuVector{UInt8}
|
||||
@ -91,7 +96,7 @@ function RNNDesc{T}(mode::Int, input::Int, hidden::Int; layers = 1) where T
|
||||
rd = RNNDesc{T}(mode, input, hidden, w, params(w, input, hidden, ngates(mode))..., d[])
|
||||
finalizer(rd) do x
|
||||
@check ccall((:cudnnDestroyRNNDescriptor,libcudnn),cudnnStatus_t,(Ptr{Nothing},),x)
|
||||
end
|
||||
end
|
||||
return rd
|
||||
end
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user