Move functionality over to CuArrays.

This commit is contained in:
Tim Besard 2019-08-30 08:39:51 +02:00
parent 1e7ff4f65d
commit 4942d7fcfd
2 changed files with 21 additions and 348 deletions

View File

@ -1,149 +1,5 @@
using CuArrays.CUDNN: handle, TensorDesc, FilterDesc
import CuArrays.CUDAdrv: CU_NULL
using LinearAlgebra
mutable struct DropoutDesc
ptr::Ptr{Nothing}
states::CuVector{UInt8}
end
Base.unsafe_convert(::Type{Ptr{Nothing}}, dd::DropoutDesc) = dd.ptr
function DropoutDesc(ρ::Real; seed::Integer=0)
d = [C_NULL]
s = Csize_t[0]
CUDNN.cudnnCreateDropoutDescriptor(d)
CUDNN.cudnnDropoutGetStatesSize(handle(), s)
states = CuArray{UInt8}(undef, s[]) # TODO: can we drop this when ρ=0?
desc = DropoutDesc(d[], states)
CUDNN.cudnnSetDropoutDescriptor(desc, handle(), ρ, states, length(states), seed)
finalizer(desc) do x
CUDNN.cudnnDestroyDropoutDescriptor(x)
end
return desc
end
@inline _wsize(y) = (map(_ -> 1, size(y)[1:end-2])..., size(y)[end-1], 1)
@inline _reddims(y) = (collect(1:ndims(y)-2)..., ndims(y))
mutable struct BNCache
mean
ivar
end
BNCache() = BNCache(nothing, nothing)
# NOTE: CuDNN supports only 4D and 5D Tensors for BatchNorm Operations
# so reshape a 2D Tensor into 4D
batchnorm(g::CuArray{T}, b::CuArray{T}, x::CuArray{T, 2},
running_mean::CuArray{T}, running_var::CuArray{T}, momentum;
cache = nothing, alpha = T(1), beta = T(0),
eps = T(1e-5), training = true) where T<:Union{Float32, Float64} =
dropdims(batchnorm(g, b, reshape(x, 1, 1, size(x, 1), size(x, 2)), running_mean, running_var, momentum,
cache = cache, alpha = alpha, beta = beta, eps = eps, training = training), dims = (1, 2))
function batchnorm(g::CuArray{T}, b::CuArray{T}, x::Union{CuArray{T, 4},CuArray{T,5}},
running_mean::CuArray{T}, running_var::CuArray{T}, momentum;
cache = nothing, alpha = T(1), beta = T(0),
eps = T(1e-5), training = true) where T<:Union{Float32, Float64}
y = similar(x)
cudnnBNForward!(y, g, b, x, running_mean, running_var, momentum, cache = cache,
alpha = alpha, beta = beta, eps = eps, training = training)
y
end
function cudnnBNForward!(y::CuArray{T}, g::CuArray{T}, b::CuArray{T}, x::CuArray{T},
running_mean::CuArray{T}, running_var::CuArray{T},
momentum; cache = nothing,
alpha = T(1), beta = T(0),
eps = T(1e-5), training = true) where T<:Union{Float32, Float64}
dims = _wsize(x)
if eps < CUDNN.CUDNN_BN_MIN_EPSILON
# warn("eps ",eps," is too small for CuDNN so eps has been assigned the value ", CUDNN.CUDNN_BN_MIN_EPSILON)
eps = CUDNN.CUDNN_BN_MIN_EPSILON
end
xd = TensorDesc(x)
yd = TensorDesc(y)
gd = TensorDesc(T, dims)
if training
if cache !== nothing
mean = zeros(CuArray{T}, dims...)
ivar = ones(CuArray{T}, dims...)
else
mean = CU_NULL
ivar = CU_NULL
end
CUDNN.cudnnBatchNormalizationForwardTraining(handle(), CUDNN.CUDNN_BATCHNORM_SPATIAL, Ref(T(alpha)), Ref(T(beta)), xd, x, yd, y, gd, g, b, momentum, running_mean, running_var, eps, mean, ivar)
if cache !== nothing
cache.mean = mean
cache.ivar = ivar
end
else
CUDNN.cudnnBatchNormalizationForwardInference(handle(), CUDNN.CUDNN_BATCHNORM_SPATIAL, Ref(T(alpha)), Ref(T(beta)), xd, x, yd, y, gd, g, b, running_mean, running_var, eps)
end
end
function ∇batchnorm(g::CuArray{T}, b::CuArray{T}, x::CuArray{T, 2}, dy::CuArray{T, 2},
running_mean::CuArray{T}, running_var::CuArray{T}, momentum;
cache = nothing, eps = T(1e-5), alpha = T(1),
beta = T(0), training = true) where T<:Union{Float32, Float64}
dg, db, dx = ∇batchnorm(g, b, reshape(x, 1, 1, size(x, 1), size(x, 2)), reshape(dy, 1, 1, size(dy, 1),
size(dy, 2)), running_mean, running_var, momentum, cache = cache, eps = eps,
alpha = alpha, beta = beta, training = training)
(dg, db, dropdims(dx, dims = (1, 2)))
end
function ∇batchnorm(g::CuArray{T}, b::CuArray{T}, x::CuArray{T}, dy::CuArray{T},
running_mean::CuArray{T}, running_var::CuArray{T}, momentum;
cache = nothing, eps = T(1e-5), alpha = T(1),
beta = T(0), training = true) where T<:Union{Float32, Float64}
dg = similar(g)
db = similar(b)
dx = similar(x)
cudnnBNBackward!(dg, g, db, dx, x, dy, running_mean, running_var, T(momentum),
training = training, cache = cache, eps = eps, alpha = alpha, beta = beta)
(dg, db, dx)
end
function cudnnBNBackward!(dg::CuArray{T}, g::CuArray{T}, db::CuArray{T},
dx::CuArray{T}, x::CuArray{T}, dy::CuArray{T},
running_mean::CuArray{T}, running_var::CuArray{T},
momentum; cache = nothing, eps = T(1e-5),
alpha = T(1), beta = T(0),
dalpha = T(1), dbeta = T(0), training = true) where T<:Union{Float32, Float64}
if training
xd = TensorDesc(x)
dyd = TensorDesc(dy)
dxd = TensorDesc(dx)
gd = TensorDesc(T, _wsize(x))
if cache !== nothing
mean, ivar = cache.mean, cache.ivar
info("mean and ivar are fetched from the cache")
else
mean, ivar = CU_NULL, CU_NULL
end
if eps < CUDNN.CUDNN_BN_MIN_EPSILON
eps = CUDNN.CUDNN_BN_MIN_EPSILON
end
CUDNN.cudnnBatchNormalizationBackward(handle(), CUDNN.CUDNN_BATCHNORM_SPATIAL, Ref(T(alpha)), Ref(T(beta)), Ref(T(dalpha)), Ref(T(dbeta)), xd, x, dyd, dy, dxd, dx, gd, g, dg, db, eps, mean, ivar)
else
ivar = 1 ./ sqrt.(reshape(running_var, _wsize(x)) .+ eps)
dx .= dy .* reshape(g, _wsize(x)) .* ivar
dg .= squeeze(sum(dy .* (x .- reshape(running_mean, _wsize(x))) .* ivar, _reddims(dy)), dims = (1,2,4))
db .= squeeze(sum(dy, _reddims(dy)), dims = (1,2,4))
end
end
# Flux Interface
import ..Flux: data
import CuArrays.CUDNN: batchnorm, ∇batchnorm
(BN::Flux.BatchNorm)(x::Union{CuArray{T,2},CuArray{T,4},CuArray{T,5}}, cache = nothing) where T<:Union{Float32, Float64} =
BN.λ.(batchnorm(BN.γ, BN.β, x, BN.μ, BN.σ², BN.momentum; cache = cache, alpha = 1, beta = 0, eps = BN.ϵ, training = Flux.istraining()))

View File

@ -1,190 +1,7 @@
using CuArrays.CUDNN: handle, cudnnDataType, TensorDesc, FilterDesc
import CuArrays.CUDAdrv: CU_NULL
using LinearAlgebra
const RNN_RELU = 0 # Stock RNN with ReLu activation
const RNN_TANH = 1 # Stock RNN with tanh activation
const LSTM = 2 # LSTM with no peephole connections
const GRU = 3 # Using h' = tanh(r * Uh(t-1) + Wx) and h = (1 - z) * h' + z * h(t-1)
const LINEAR_INPUT = 0
const SKIP_INPUT = 1
const UNIDIRECTIONAL = 0
const BIDIRECTIONAL = 1
const RNN_ALGO_STANDARD = 0
const RNN_ALGO_PERSIST_STATIC = 1
const RNN_ALGO_PERSIST_DYNAMIC = 2
# param layout:
# RNN: [weight, bias] × [input, hidden]
# GRU: [weight, bias] × [input, hidden] × [reset, update, newmem]
# LSTM: [weight, bias] × [input, hidden] × [input, forget, newmem, output]
function params(w::CuVector, input, hidden, n = 1)
slice(offset, shape) = reshape(view(w, offset.+(1:prod(shape))), shape)
wx = slice(0, (input, hidden*n))
wh = slice(length(wx), (hidden, hidden*n))
bias = view(w, length(wx)+length(wh) .+ (1:hidden*n))
(wx, wh), bias
end
mutable struct RNNDesc{T}
mode::Int
input::Int
hidden::Int
params::CuVector{T}
weights::NTuple{2,CuMatrix{T}}
bias::CuVector{T}
ptr::Ptr{Nothing}
end
Base.unsafe_convert(::Type{Ptr{Nothing}}, d::RNNDesc) = d.ptr
function rnnParamSize(T, r, input)
size = Csize_t[0]
CUDNN.cudnnGetRNNParamsSize(handle(), r, TensorDesc(T, (1,input,1)), size, cudnnDataType(T))
return Int(size[])÷sizeof(T)
end
ngates(mode) = [1, 1, 4, 3][mode+1]
ngates(r::RNNDesc) = ngates(r.mode)
function RNNDesc{T}(mode::Int, input::Int, hidden::Int; layers = 1) where T
d = [C_NULL]
CUDNN.cudnnCreateRNNDescriptor(d)
dropoutDesc = DropoutDesc(0)
inputMode = LINEAR_INPUT
direction = UNIDIRECTIONAL
algo = RNN_ALGO_STANDARD
CUDNN.cudnnSetRNNDescriptor_v6(handle(),d[],hidden,layers,dropoutDesc,CUDNN.cudnnRNNInputMode_t(inputMode),CUDNN.cudnnDirectionMode_t(direction),CUDNN.cudnnRNNMode_t(mode),CUDNN.cudnnRNNAlgo_t(algo),cudnnDataType(T))
w =CuArrays.zeros(T, rnnParamSize(T, d[], input))
# TODO: avoid reserve allocation here
rd = RNNDesc{T}(mode, input, hidden, w, params(w, input, hidden, ngates(mode))..., d[])
finalizer(rd) do x
CUDNN.cudnnDestroyRNNDescriptor(x)
end
return rd
end
function rnnWorkspaceSize(r::RNNDesc, seqlen, xdesc)
size = Csize_t[0]
CUDNN.cudnnGetRNNWorkspaceSize(handle(), r, seqlen, xdesc, size)
return Int(size[])
end
const workspace = [CuVector{UInt8}(undef, 1)]
getworkspace(bytes) =
length(workspace[]) bytes ?
workspace[] :
(workspace[] = CuVector{UInt8}(undef, bytes))
getworkspace(r::RNNDesc, seqlen, xdesc) =
getworkspace(rnnWorkspaceSize(r, seqlen, xdesc))
function rnnTrainingReserveSize(r::RNNDesc, seqlen, xdesc)
size = Csize_t[0]
CUDNN.cudnnGetRNNTrainingReserveSize(handle(), r, seqlen, xdesc, size)
return Int(size[])
end
function cudnnRNNForward(rnn::RNNDesc{T}, seqlen, xd, x, hd, h, cd, c, wd, w, yd, y, hod, ho, cod, co,
workspace, reserve=nothing) where T
if reserve == nothing
CUDNN.cudnnRNNForwardInference(handle(), rnn, seqlen, xd, x, hd, h, cd, c, wd, w, yd, y,
hod, ho, cod, co, workspace, length(workspace))
else
CUDNN.cudnnRNNForwardTraining(handle(), rnn, seqlen, xd, x, hd, h, cd, c, wd, w, yd, y,
hod, ho, cod, co, workspace, length(workspace), reserve, length(reserve))
end
end
xDesc(x) = [TensorDesc(eltype(x), (1, size(x, 1), size(x, 2)))]
hDesc(h::Nothing) = C_NULL, CU_NULL
hDesc(x::Integer) = (@assert x == 0; hDesc(nothing))
function hDesc(h::CuArray)
TensorDesc(eltype(h), (size(h, 1), size(h, 2), 1)), h
end
# TODO: can we just manipulate strides here?
# TODO: should use repmat, but this isn't implemented.
hBatch(x::AbstractVector, h::CuVector) = h
hBatch(x::AbstractMatrix, h::CuVector) = h .*CuArrays.ones(1, size(x, 2))
hBatch(x::AbstractMatrix, h::CuMatrix) = h .*CuArrays.ones(1, size(h,2) == 1 ? size(x,2) : 1)
function forward(rnn::RNNDesc{T}, x::CuArray{T}, h_::CuArray{T}, c_ = nothing, train = Val{false}) where T
h = hBatch(x, h_)
c = c_ == nothing ? nothing : hBatch(x, c_)
@assert size(x, 1) == rnn.input
@assert size(h, 1) == rnn.hidden
@assert size(x, 2) == size(h, 2)
seqLength = 1
xdesc = xDesc(x)
y = x isa AbstractVector ? similar(x, rnn.hidden) : similar(x, rnn.hidden, size(x, 2))
ho = similar(h)
ydesc = xDesc(y)
workspace = getworkspace(rnn, seqLength, xdesc)
reserve = train == Val{true} ?
CuVector{UInt8}(undef, rnnTrainingReserveSize(rnn, seqLength, xdesc)) :
nothing
co = c == nothing ? c : similar(c)
cudnnRNNForward(rnn, seqLength,
xdesc, x,
hDesc(h)...,
hDesc(c)...,
FilterDesc(T, (1, 1, length(rnn.params))), rnn.params,
ydesc, y,
hDesc(ho)...,
hDesc(co)...,
workspace, reserve)
result = c == nothing ? (y, ho) : (y, ho, co)
return train == Val{true} ? (reserve, result) : result
end
forwardTrain(rnn::RNNDesc{T}, x::CuArray{T}, h::CuArray{T}, c = nothing) where T =
forward(rnn, x, h, c, Val{true})
function backwardData(rnn::RNNDesc{T}, y, dy_, dho, dco, h, c, reserve) where T
# Same as above, any more efficient way?
dy = dy_ isa Integer ? zero(y) : dy_
yd = xDesc(y)
dx = y isa AbstractVector ? similar(dy, rnn.input) : similar(dy, rnn.input, size(dy, 2))
dh = similar(h)
dc = c == nothing ? nothing : similar(c)
workspace = getworkspace(rnn, 1, yd)
CUDNN.cudnnRNNBackwardData(handle(), rnn, 1,
yd, y, yd, dy, hDesc(dho)..., hDesc(dco)...,
FilterDesc(T, (1, 1, length(rnn.params))), rnn.params,
hDesc(h)..., hDesc(c)..., xDesc(dx), dx, hDesc(dh)..., hDesc(dc)...,
workspace, length(workspace), reserve, length(reserve))
return c == nothing ? (dx, dh) : (dx, dh, dc)
end
backwardData(rnn, y, dy, dho, hx, reserve) =
backwardData(rnn, y, dy, dho, nothing, hx, nothing, reserve)
function backwardWeights(rnn::RNNDesc{T}, x, h, y, reserve) where T
dw = zero(rnn.params)
CUDNN.cudnnRNNBackwardWeights(handle(), rnn, 1,
xDesc(x), x, hDesc(h)..., xDesc(y), y,
workspace[], length(workspace[]),
FilterDesc(T, (1, 1, length(dw))), dw,
reserve, length(reserve))
return params(dw, rnn.input, rnn.hidden, ngates(rnn))
end
# Interface
import ..Flux: Flux, relu
using CuArrays.CUDAnative
using CuArrays: @cuindex, cudims
using LinearAlgebra
function LinearAlgebra.copy_transpose!(dst::CuArray, src::CuArray)
function kernel(dst, src)
@ -202,7 +19,7 @@ CuGRU{T} = Flux.GRUCell{<:CuArray{T,2},<:CuArray{T,1}}
CuLSTM{T} = Flux.LSTMCell{<:CuArray{T,2},<:CuArray{T,1}}
CuRNNs{T} = Union{CuRNN{T},CuGRU{T},CuLSTM{T}}
function copyparams!(m::CuRNNs, d::RNNDesc)
function copyparams!(m::CuRNNs, d::CUDNN.RNNDesc)
Wi, Wh = d.weights
copy_transpose!(Wi, m.Wi)
copy_transpose!(Wh, m.Wh)
@ -210,19 +27,19 @@ function copyparams!(m::CuRNNs, d::RNNDesc)
return
end
function RNNDesc(m::CuRNNs{T}) where T
function CUDNN.RNNDesc(m::CuRNNs{T}) where T
h, i = length(m.h), size(m.Wi, 2)
mode = m isa CuRNN ?
(m.σ == tanh ? RNN_TANH : RNN_RELU) :
m isa CuGRU ? GRU : LSTM
r = RNNDesc{T}(mode, i, h)
(m.σ == tanh ? CUDNN.CUDNN_RNN_TANH : CUDNN.CUDNN_RNN_RELU) :
m isa CuGRU ? CUDNN.CUDNN_GRU : CUDNN.CUDNN_LSTM
r = CUDNN.RNNDesc{T}(mode, i, h)
return r
end
const descs = WeakKeyDict()
function desc(rnn)
d = haskey(descs, rnn) ? descs[rnn] : (descs[rnn] = RNNDesc(rnn))
d = haskey(descs, rnn) ? descs[rnn] : (descs[rnn] = CUDNN.RNNDesc(rnn))
copyparams!(rnn, d)
return d
end
@ -230,17 +47,17 @@ end
using ..Flux: @adjoint
function (m::CuRNN{T})(h::CuArray{T}, x::CuArray{T}) where T <: Union{Float32,Float64}
y, h = forward(desc(m), x, h)
y, h = CUDNN.forward(desc(m), x, h)
return h, y
end
function (m::CuGRU{T})(h::CuArray{T}, x::CuArray{T}) where T <: Union{Float32,Float64}
y, h = forward(desc(m), x, h)
y, h = CUDNN.forward(desc(m), x, h)
return h, y
end
function (m::CuLSTM{T})(h::NTuple{2,CuArray{T}}, x::CuArray{T}) where T <: Union{Float32,Float64}
y, h, c = forward(desc(m), x, h[1], h[2])
y, h, c = CUDNN.forward(desc(m), x, h[1], h[2])
return (h, c), y
end
@ -257,12 +74,12 @@ unbroadcast(x::AbstractArray, Δ) =
for RNN in (CuRNN, CuGRU)
@eval @adjoint function (m::$RNN{T})(h::CuArray{T}, x::CuArray{T}) where T <: Union{Float32,Float64}
reserve, (y, ho) = forwardTrain(desc(m), x, h)
reserve, (y, ho) = CUDNN.forwardTrain(desc(m), x, h)
(ho, y), function (Δ)
dho, dy = Δ
h_ = hBatch(x, h)
dx, dh = backwardData(descs[m], y, dy, dho, h_, reserve)
(dWi, dWh), db = backwardWeights(descs[m], x, h_, y, reserve)
h_ = CUDNN.hBatch(x, h)
dx, dh = CUDNN.backwardData(descs[m], y, dy, dho, h_, reserve)
(dWi, dWh), db = CUDNN.backwardWeights(descs[m], x, h_, y, reserve)
dm = Ref{Any}((σ=nothing,Wi=transpose(dWi),Wh=transpose(dWh),b=db,h=nothing))
(dm, unbroadcast(h, dh), dx)
end
@ -270,14 +87,14 @@ for RNN in (CuRNN, CuGRU)
end
@adjoint function (m::CuLSTM)((h, c)::Tuple{CuArray{T},CuArray{T}}, x::CuArray{T}) where T <: Union{Float32,Float64}
reserve, (y, ho, co) = forwardTrain(desc(m), x, h, c)
reserve, (y, ho, co) = CUDNN.forwardTrain(desc(m), x, h, c)
((ho, co), y), function (Δ)
dhc, dy = Δ
dho, dco = dhc === nothing ? (nothing, nothing) : dhc
h_ = hBatch(x, h)
c_ = hBatch(x, c)
dx, dh, dc = backwardData(descs[m], y, dy, dho, dco, h_, c_, reserve)
(dWi, dWh), db = backwardWeights(descs[m], x, h_, y, reserve)
h_ = CUDNN.hBatch(x, h)
c_ = CUDNN.hBatch(x, c)
dx, dh, dc = CUDNN.backwardData(descs[m], y, dy, dho, dco, h_, c_, reserve)
(dWi, dWh), db = CUDNN.backwardWeights(descs[m], x, h_, y, reserve)
dm = Ref{Any}((Wi=transpose(dWi),Wh=transpose(dWh),b=db,h=nothing,c=nothing))
(dm, (unbroadcast(h, dh), unbroadcast(c, dc)), dx)
end