Merge pull request #600 from FluxML/tb/cuptr
Adapt to the new CUDAdrv.CuPtr pointer type.
This commit is contained in:
commit
f2dc57f938
@ -1,17 +1,18 @@
|
|||||||
module CUDA
|
module CUDA
|
||||||
|
|
||||||
using ..CuArrays
|
using ..CuArrays
|
||||||
|
import ..CuArrays.CUDAdrv: CuPtr, CU_NULL
|
||||||
using Pkg.TOML
|
using Pkg.TOML
|
||||||
|
|
||||||
function version_check()
|
function version_check()
|
||||||
minor_version = 9
|
major_version = 1
|
||||||
project = joinpath(dirname(pathof(CuArrays)), "../Project.toml")
|
project = joinpath(dirname(pathof(CuArrays)), "../Project.toml")
|
||||||
project = TOML.parse(String(read(project)))
|
project = TOML.parse(String(read(project)))
|
||||||
version = VersionNumber(get(project, "version", "0.0.0"))
|
version = VersionNumber(get(project, "version", "0.0.0"))
|
||||||
if !(version.major == 0 && version.minor == minor_version)
|
if version.major != major_version
|
||||||
@warn """
|
@warn """
|
||||||
Flux is only supported with CuArrays v0.$minor_version.
|
Flux is only supported with CuArrays v$major_version.x.
|
||||||
Try running `] pin CuArrays@0.$minor_version`.
|
Try running `] pin CuArrays@$major_version`.
|
||||||
"""
|
"""
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
@ -17,7 +17,7 @@ function DropoutDesc(ρ::Real; seed::Integer=0)
|
|||||||
@check ccall((:cudnnDropoutGetStatesSize,libcudnn),cudnnStatus_t,(Ptr{Nothing},Ptr{Csize_t}),handle(),s)
|
@check ccall((:cudnnDropoutGetStatesSize,libcudnn),cudnnStatus_t,(Ptr{Nothing},Ptr{Csize_t}),handle(),s)
|
||||||
states = CuArray{UInt8}(undef, s[]) # TODO: can we drop this when ρ=0?
|
states = CuArray{UInt8}(undef, s[]) # TODO: can we drop this when ρ=0?
|
||||||
desc = DropoutDesc(d[], states)
|
desc = DropoutDesc(d[], states)
|
||||||
@check ccall((:cudnnSetDropoutDescriptor,libcudnn),cudnnStatus_t,(Ptr{Nothing},Ptr{Nothing},Cfloat,Ptr{Nothing},Csize_t,Culonglong),
|
@check ccall((:cudnnSetDropoutDescriptor,libcudnn),cudnnStatus_t,(Ptr{Nothing},Ptr{Nothing},Cfloat,CuPtr{Nothing},Csize_t,Culonglong),
|
||||||
desc,handle(),ρ,states,length(states),seed)
|
desc,handle(),ρ,states,length(states),seed)
|
||||||
finalizer(desc) do x
|
finalizer(desc) do x
|
||||||
@check ccall((:cudnnDestroyDropoutDescriptor,libcudnn),cudnnStatus_t,(Ptr{Nothing},),x)
|
@check ccall((:cudnnDestroyDropoutDescriptor,libcudnn),cudnnStatus_t,(Ptr{Nothing},),x)
|
||||||
@ -79,18 +79,18 @@ function cudnnBNForward!(y::CuArray{T}, g::CuArray{T}, b::CuArray{T}, x::CuArray
|
|||||||
mean = zeros(CuArray{T}, dims...)
|
mean = zeros(CuArray{T}, dims...)
|
||||||
ivar = ones(CuArray{T}, dims...)
|
ivar = ones(CuArray{T}, dims...)
|
||||||
else
|
else
|
||||||
mean = C_NULL
|
mean = CU_NULL
|
||||||
ivar = C_NULL
|
ivar = CU_NULL
|
||||||
end
|
end
|
||||||
|
|
||||||
@check ccall((:cudnnBatchNormalizationForwardTraining, libcudnn), cudnnStatus_t,
|
@check ccall((:cudnnBatchNormalizationForwardTraining, libcudnn), cudnnStatus_t,
|
||||||
(cudnnHandle_t,cudnnBatchNormMode_t,
|
(cudnnHandle_t,cudnnBatchNormMode_t,
|
||||||
Ptr{T}, Ptr{T},
|
Ptr{T}, Ptr{T},
|
||||||
Ptr{Nothing}, Ptr{T},
|
Ptr{Nothing}, CuPtr{T},
|
||||||
Ptr{Nothing}, Ptr{T},
|
Ptr{Nothing}, CuPtr{T},
|
||||||
Ptr{Nothing}, Ptr{T}, Ptr{T},
|
Ptr{Nothing}, CuPtr{T}, CuPtr{T},
|
||||||
Cdouble, Ptr{T}, Ptr{T},
|
Cdouble, CuPtr{T}, CuPtr{T},
|
||||||
Cdouble, Ptr{T}, Ptr{T}),
|
Cdouble, CuPtr{T}, CuPtr{T}),
|
||||||
handle(), BATCHNORM_SPATIAL,
|
handle(), BATCHNORM_SPATIAL,
|
||||||
Ref(T(alpha)), Ref(T(beta)),
|
Ref(T(alpha)), Ref(T(beta)),
|
||||||
xd, x,
|
xd, x,
|
||||||
@ -107,10 +107,10 @@ function cudnnBNForward!(y::CuArray{T}, g::CuArray{T}, b::CuArray{T}, x::CuArray
|
|||||||
@check ccall((:cudnnBatchNormalizationForwardInference, libcudnn), cudnnStatus_t,
|
@check ccall((:cudnnBatchNormalizationForwardInference, libcudnn), cudnnStatus_t,
|
||||||
(Ptr{cudnnHandle_t},cudnnBatchNormMode_t,
|
(Ptr{cudnnHandle_t},cudnnBatchNormMode_t,
|
||||||
Ptr{T}, Ptr{T},
|
Ptr{T}, Ptr{T},
|
||||||
Ptr{Nothing}, Ptr{T},
|
Ptr{Nothing}, CuPtr{T},
|
||||||
Ptr{Nothing}, Ptr{T},
|
Ptr{Nothing}, CuPtr{T},
|
||||||
Ptr{Nothing}, Ptr{T}, Ptr{T},
|
Ptr{Nothing}, CuPtr{T}, CuPtr{T},
|
||||||
Ptr{T}, Ptr{T},
|
CuPtr{T}, CuPtr{T},
|
||||||
Cdouble),
|
Cdouble),
|
||||||
handle(), BATCHNORM_SPATIAL,
|
handle(), BATCHNORM_SPATIAL,
|
||||||
Ref(T(alpha)), Ref(T(beta)),
|
Ref(T(alpha)), Ref(T(beta)),
|
||||||
@ -159,7 +159,7 @@ function cudnnBNBackward!(dg::CuArray{T}, g::CuArray{T}, db::CuArray{T},
|
|||||||
mean, ivar = cache.mean, cache.ivar
|
mean, ivar = cache.mean, cache.ivar
|
||||||
info("mean and ivar are fetched from the cache")
|
info("mean and ivar are fetched from the cache")
|
||||||
else
|
else
|
||||||
mean, ivar = C_NULL, C_NULL
|
mean, ivar = CU_NULL, CU_NULL
|
||||||
end
|
end
|
||||||
|
|
||||||
if eps < BATCHNORM_MIN_EPS
|
if eps < BATCHNORM_MIN_EPS
|
||||||
@ -170,11 +170,11 @@ function cudnnBNBackward!(dg::CuArray{T}, g::CuArray{T}, db::CuArray{T},
|
|||||||
(cudnnHandle_t,cudnnBatchNormMode_t,
|
(cudnnHandle_t,cudnnBatchNormMode_t,
|
||||||
Ptr{T}, Ptr{T},
|
Ptr{T}, Ptr{T},
|
||||||
Ptr{T}, Ptr{T},
|
Ptr{T}, Ptr{T},
|
||||||
Ptr{Nothing}, Ptr{T},
|
Ptr{Nothing}, CuPtr{T},
|
||||||
Ptr{Nothing}, Ptr{T},
|
Ptr{Nothing}, CuPtr{T},
|
||||||
Ptr{Nothing}, Ptr{T},
|
Ptr{Nothing}, CuPtr{T},
|
||||||
Ptr{Nothing}, Ptr{T}, Ptr{T}, Ptr{T},
|
Ptr{Nothing}, CuPtr{T}, CuPtr{T}, CuPtr{T},
|
||||||
Cdouble, Ptr{T}, Ptr{T}),
|
Cdouble, CuPtr{T}, CuPtr{T}),
|
||||||
handle(), BATCHNORM_SPATIAL,
|
handle(), BATCHNORM_SPATIAL,
|
||||||
Ref(T(alpha)), Ref(T(beta)),
|
Ref(T(alpha)), Ref(T(beta)),
|
||||||
Ref(T(dalpha)), Ref(T(dbeta)),
|
Ref(T(dalpha)), Ref(T(dbeta)),
|
||||||
|
@ -101,18 +101,18 @@ function cudnnRNNForward(rnn::RNNDesc{T}, seqlen, xd, x, hd, h, cd, c, wd, w, yd
|
|||||||
if reserve == nothing
|
if reserve == nothing
|
||||||
@check ccall((:cudnnRNNForwardInference, libcudnn), cudnnStatus_t,
|
@check ccall((:cudnnRNNForwardInference, libcudnn), cudnnStatus_t,
|
||||||
(Ptr{Nothing}, Ptr{Nothing}, Cint,
|
(Ptr{Nothing}, Ptr{Nothing}, Cint,
|
||||||
Ptr{Ptr{Nothing}}, Ptr{T}, Ptr{Nothing}, Ptr{T}, Ptr{Nothing}, Ptr{T},
|
Ptr{Ptr{Nothing}}, CuPtr{T}, Ptr{Nothing}, CuPtr{T}, Ptr{Nothing}, CuPtr{T},
|
||||||
Ptr{Nothing}, Ptr{T}, Ptr{Ptr{Nothing}}, Ptr{T}, Ptr{Nothing}, Ptr{T},
|
Ptr{Nothing}, CuPtr{T}, Ptr{Ptr{Nothing}}, CuPtr{T}, Ptr{Nothing}, CuPtr{T},
|
||||||
Ptr{Nothing}, Ptr{T},
|
Ptr{Nothing}, CuPtr{T},
|
||||||
Ptr{Nothing}, Csize_t),
|
CuPtr{Nothing}, Csize_t),
|
||||||
handle(), rnn, seqlen,
|
handle(), rnn, seqlen,
|
||||||
xd, x, hd, h, cd, c, wd, w, yd, y, hod, ho, cod, co,
|
xd, x, hd, h, cd, c, wd, w, yd, y, hod, ho, cod, co,
|
||||||
workspace, length(workspace))
|
workspace, length(workspace))
|
||||||
else
|
else
|
||||||
@check ccall((:cudnnRNNForwardTraining, libcudnn), cudnnStatus_t,
|
@check ccall((:cudnnRNNForwardTraining, libcudnn), cudnnStatus_t,
|
||||||
(Ptr{Nothing}, Ptr{Nothing}, Cint,
|
(Ptr{Nothing}, Ptr{Nothing}, Cint,
|
||||||
Ptr{Ptr{Nothing}}, Ptr{T}, Ptr{Nothing}, Ptr{T}, Ptr{Nothing}, Ptr{T}, Ptr{Nothing}, Ptr{T}, Ptr{Ptr{Nothing}}, Ptr{T}, Ptr{Nothing}, Ptr{T}, Ptr{Nothing}, Ptr{T},
|
Ptr{Ptr{Nothing}}, CuPtr{T}, Ptr{Nothing}, CuPtr{T}, Ptr{Nothing}, CuPtr{T}, Ptr{Nothing}, CuPtr{T}, Ptr{Ptr{Nothing}}, CuPtr{T}, Ptr{Nothing}, CuPtr{T}, Ptr{Nothing}, CuPtr{T},
|
||||||
Ptr{Nothing}, Csize_t, Ptr{Nothing}, Csize_t),
|
CuPtr{Nothing}, Csize_t, CuPtr{Nothing}, Csize_t),
|
||||||
handle(), rnn, seqlen,
|
handle(), rnn, seqlen,
|
||||||
xd, x, hd, h, cd, c, wd, w, yd, y, hod, ho, cod, co,
|
xd, x, hd, h, cd, c, wd, w, yd, y, hod, ho, cod, co,
|
||||||
workspace, length(workspace), reserve, length(reserve))
|
workspace, length(workspace), reserve, length(reserve))
|
||||||
@ -121,7 +121,7 @@ end
|
|||||||
|
|
||||||
xDesc(x) = [TensorDesc(eltype(x), (1, size(x, 1), size(x, 2)))]
|
xDesc(x) = [TensorDesc(eltype(x), (1, size(x, 1), size(x, 2)))]
|
||||||
|
|
||||||
hDesc(h::Nothing) = C_NULL, C_NULL
|
hDesc(h::Nothing) = C_NULL, CU_NULL
|
||||||
hDesc(x::Integer) = (@assert x == 0; hDesc(nothing))
|
hDesc(x::Integer) = (@assert x == 0; hDesc(nothing))
|
||||||
function hDesc(h::CuArray)
|
function hDesc(h::CuArray)
|
||||||
TensorDesc(eltype(h), (size(h, 1), size(h, 2), 1)), h
|
TensorDesc(eltype(h), (size(h, 1), size(h, 2), 1)), h
|
||||||
@ -169,10 +169,10 @@ function cudnnRNNBackwardData(rnn::RNNDesc{T}, seqlen, yd, y, dyd, dy, dhod, dho
|
|||||||
wd, w, hd, h, cd, c, dxd, dx, dhd, dh, dcd, dc, ws, rs) where T
|
wd, w, hd, h, cd, c, dxd, dx, dhd, dh, dcd, dc, ws, rs) where T
|
||||||
@check ccall((:cudnnRNNBackwardData,libcudnn),cudnnStatus_t,
|
@check ccall((:cudnnRNNBackwardData,libcudnn),cudnnStatus_t,
|
||||||
(Ptr{Nothing}, Ptr{Nothing}, Cint,
|
(Ptr{Nothing}, Ptr{Nothing}, Cint,
|
||||||
Ptr{Ptr{Nothing}}, Ptr{T}, Ptr{Ptr{Nothing}}, Ptr{T}, Ptr{Nothing}, Ptr{T},
|
Ptr{Ptr{Nothing}}, CuPtr{T}, Ptr{Ptr{Nothing}}, CuPtr{T}, Ptr{Nothing}, CuPtr{T},
|
||||||
Ptr{Nothing}, Ptr{T}, Ptr{Nothing}, Ptr{T}, Ptr{Nothing}, Ptr{T}, Ptr{Nothing},
|
Ptr{Nothing}, CuPtr{T}, Ptr{Nothing}, CuPtr{T}, Ptr{Nothing}, CuPtr{T}, Ptr{Nothing},
|
||||||
Ptr{T}, Ptr{Ptr{Nothing}}, Ptr{T}, Ptr{Nothing}, Ptr{T}, Ptr{Nothing}, Ptr{T},
|
CuPtr{T}, Ptr{Ptr{Nothing}}, CuPtr{T}, Ptr{Nothing}, CuPtr{T}, Ptr{Nothing}, CuPtr{T},
|
||||||
Ptr{Nothing}, Csize_t, Ptr{Nothing}, Csize_t),
|
CuPtr{Nothing}, Csize_t, CuPtr{Nothing}, Csize_t),
|
||||||
handle(), rnn, seqlen, yd, y, dyd, dy, dhod, dho, dcod, dco,
|
handle(), rnn, seqlen, yd, y, dyd, dy, dhod, dho, dcod, dco,
|
||||||
wd, w, hd, h, cd, c, dxd, dx, dhd, dh, dcd, dc, ws, length(ws), rs, length(rs))
|
wd, w, hd, h, cd, c, dxd, dx, dhd, dh, dcd, dc, ws, length(ws), rs, length(rs))
|
||||||
end
|
end
|
||||||
@ -199,12 +199,12 @@ function cudnnRNNBackwardWeights(rnn::RNNDesc{T}, seqlen, xd, x, hd, h, yd, y, d
|
|||||||
workspace, reserve) where T
|
workspace, reserve) where T
|
||||||
@check ccall((:cudnnRNNBackwardWeights,libcudnn), cudnnStatus_t,
|
@check ccall((:cudnnRNNBackwardWeights,libcudnn), cudnnStatus_t,
|
||||||
(Ptr{Nothing}, Ptr{Nothing}, Cint, # handle, rnnDesc, seqLength
|
(Ptr{Nothing}, Ptr{Nothing}, Cint, # handle, rnnDesc, seqLength
|
||||||
Ptr{Ptr{Nothing}}, Ptr{T}, #x
|
Ptr{Ptr{Nothing}}, CuPtr{T}, #x
|
||||||
Ptr{Nothing}, Ptr{T}, #hx
|
Ptr{Nothing}, CuPtr{T}, #hx
|
||||||
Ptr{Ptr{Nothing}}, Ptr{T}, #y
|
Ptr{Ptr{Nothing}}, CuPtr{T}, #y
|
||||||
Ptr{Nothing}, Csize_t, #ws
|
CuPtr{Nothing}, Csize_t, #ws
|
||||||
Ptr{Nothing}, Ptr{T}, #dw
|
Ptr{Nothing}, CuPtr{T}, #dw
|
||||||
Ptr{Nothing}, Csize_t), #rs
|
CuPtr{Nothing}, Csize_t), #rs
|
||||||
handle(), rnn, seqlen, xd, x, hd, h, yd, y,
|
handle(), rnn, seqlen, xd, x, hd, h, yd, y,
|
||||||
workspace, length(workspace), dwd, dw, reserve, length(reserve))
|
workspace, length(workspace), dwd, dw, reserve, length(reserve))
|
||||||
end
|
end
|
||||||
|
Loading…
Reference in New Issue
Block a user