Move functionality over to CuArrays.
This commit is contained in:
parent
1e7ff4f65d
commit
4942d7fcfd
@ -1,149 +1,5 @@
|
|||||||
using CuArrays.CUDNN: handle, TensorDesc, FilterDesc
|
import ..Flux: data
|
||||||
|
import CuArrays.CUDNN: batchnorm, ∇batchnorm
|
||||||
import CuArrays.CUDAdrv: CU_NULL
|
|
||||||
|
|
||||||
using LinearAlgebra
|
|
||||||
|
|
||||||
# Wrapper around a cuDNN dropout descriptor pointer. The device buffer handed
# to cuDNN as dropout state storage is kept in `states` alongside the pointer.
mutable struct DropoutDesc
    ptr::Ptr{Nothing}
    states::CuVector{UInt8}
end

# Lets a `DropoutDesc` be passed wherever a raw `Ptr{Nothing}` is expected.
function Base.unsafe_convert(::Type{Ptr{Nothing}}, dd::DropoutDesc)
    return dd.ptr
end

# Create a dropout descriptor with dropout value `ρ` and RNG `seed`.
# The state buffer is sized by querying cuDNN, and the descriptor is destroyed
# by a finalizer attached to the returned object.
function DropoutDesc(ρ::Real; seed::Integer=0)
    slot = [C_NULL]        # receives the descriptor pointer
    nbytes = Csize_t[0]    # receives the required state size
    CUDNN.cudnnCreateDropoutDescriptor(slot)
    CUDNN.cudnnDropoutGetStatesSize(handle(), nbytes)
    # TODO: can we drop this when ρ=0?
    states = CuArray{UInt8}(undef, nbytes[])
    desc = DropoutDesc(slot[], states)
    CUDNN.cudnnSetDropoutDescriptor(desc, handle(), ρ, states, length(states), seed)
    finalizer(CUDNN.cudnnDestroyDropoutDescriptor, desc)
    return desc
end
|
|
||||||
|
|
||||||
# Shape with the same number of dimensions as `y` that is 1 everywhere except
# the second-to-last dimension, which keeps `y`'s size there.
@inline function _wsize(y)
    sz = size(y)
    return (map(_ -> 1, sz[1:end-2])..., sz[end-1], 1)
end
|
|
||||||
|
|
||||||
# Dimensions to reduce over: every dimension of `y` except the second-to-last.
# The range is splatted directly; materialising it with `collect` first only
# allocates a temporary Vector.
@inline _reddims(y) = ((1:ndims(y)-2)..., ndims(y))
|
|
||||||
|
|
||||||
# Mutable holder for a `mean` / `ivar` pair. Both fields are untyped so they
# can start out as `nothing` and be assigned arrays later.
mutable struct BNCache
    mean::Any
    ivar::Any
end

# An empty cache: both fields are `nothing`.
function BNCache()
    return BNCache(nothing, nothing)
end
|
|
||||||
|
|
||||||
# NOTE: CuDNN supports only 4D and 5D tensors for batch-norm operations, so a
# 2D input is reshaped to 1×1×size(x,1)×size(x,2), normalised through the
# 4D method, and the two leading singleton dimensions are dropped again.
function batchnorm(g::CuArray{T}, b::CuArray{T}, x::CuArray{T, 2},
                   running_mean::CuArray{T}, running_var::CuArray{T}, momentum;
                   cache = nothing, alpha = T(1), beta = T(0),
                   eps = T(1e-5), training = true) where T<:Union{Float32, Float64}
    x4 = reshape(x, 1, 1, size(x, 1), size(x, 2))
    y4 = batchnorm(g, b, x4, running_mean, running_var, momentum,
                   cache = cache, alpha = alpha, beta = beta, eps = eps,
                   training = training)
    return dropdims(y4, dims = (1, 2))
end
|
|
||||||
|
|
||||||
# Batch normalisation of a 4D or 5D `x`. Allocates an output shaped like `x`,
# fills it via `cudnnBNForward!` with all options forwarded, and returns it.
function batchnorm(g::CuArray{T}, b::CuArray{T}, x::Union{CuArray{T, 4},CuArray{T,5}},
                   running_mean::CuArray{T}, running_var::CuArray{T}, momentum;
                   cache = nothing, alpha = T(1), beta = T(0),
                   eps = T(1e-5), training = true) where T<:Union{Float32, Float64}
    out = similar(x)
    cudnnBNForward!(out, g, b, x, running_mean, running_var, momentum;
                    cache = cache, alpha = alpha, beta = beta,
                    eps = eps, training = training)
    return out
end
|
|
||||||
|
|
||||||
# Write the batch-normalised `x` into `y` using cuDNN's spatial batch-norm mode.
#
# With `training = true` the training kernel is called with `momentum` and the
# running statistics. If a `cache` is supplied, freshly allocated `mean` /
# `ivar` buffers are handed to cuDNN and then stored on the cache; otherwise
# null pointers are passed. With `training = false` the inference kernel is
# called instead.
function cudnnBNForward!(y::CuArray{T}, g::CuArray{T}, b::CuArray{T}, x::CuArray{T},
                         running_mean::CuArray{T}, running_var::CuArray{T},
                         momentum; cache = nothing,
                         alpha = T(1), beta = T(0),
                         eps = T(1e-5), training = true) where T<:Union{Float32, Float64}
    wdims = _wsize(x)

    # An `eps` below CUDNN_BN_MIN_EPSILON is too small for CuDNN, so it is
    # silently raised to that minimum.
    if eps < CUDNN.CUDNN_BN_MIN_EPSILON
        eps = CUDNN.CUDNN_BN_MIN_EPSILON
    end

    xd = TensorDesc(x)
    yd = TensorDesc(y)
    gd = TensorDesc(T, wdims)

    if !training
        return CUDNN.cudnnBatchNormalizationForwardInference(
            handle(), CUDNN.CUDNN_BATCHNORM_SPATIAL,
            Ref(T(alpha)), Ref(T(beta)),
            xd, x, yd, y, gd, g, b,
            running_mean, running_var, eps)
    end

    usecache = cache !== nothing
    if usecache
        batch_mean = zeros(CuArray{T}, wdims...)
        batch_ivar = ones(CuArray{T}, wdims...)
    else
        batch_mean = CU_NULL
        batch_ivar = CU_NULL
    end

    CUDNN.cudnnBatchNormalizationForwardTraining(
        handle(), CUDNN.CUDNN_BATCHNORM_SPATIAL,
        Ref(T(alpha)), Ref(T(beta)),
        xd, x, yd, y, gd, g, b,
        momentum, running_mean, running_var, eps,
        batch_mean, batch_ivar)

    if usecache
        cache.mean = batch_mean
        cache.ivar = batch_ivar
    end
end
|
|
||||||
|
|
||||||
# Gradient of batch norm for 2D inputs. `x` and `dy` are reshaped to
# 1×1×size(·,1)×size(·,2), the general method is called, and the leading
# singleton dimensions are dropped from `dx`. Returns `(dg, db, dx)`.
function ∇batchnorm(g::CuArray{T}, b::CuArray{T}, x::CuArray{T, 2}, dy::CuArray{T, 2},
                    running_mean::CuArray{T}, running_var::CuArray{T}, momentum;
                    cache = nothing, eps = T(1e-5), alpha = T(1),
                    beta = T(0), training = true) where T<:Union{Float32, Float64}
    x4 = reshape(x, 1, 1, size(x, 1), size(x, 2))
    dy4 = reshape(dy, 1, 1, size(dy, 1), size(dy, 2))
    dg, db, dx4 = ∇batchnorm(g, b, x4, dy4, running_mean, running_var, momentum,
                             cache = cache, eps = eps, alpha = alpha,
                             beta = beta, training = training)
    return (dg, db, dropdims(dx4, dims = (1, 2)))
end
|
|
||||||
|
|
||||||
# Gradient of batch norm. Allocates `dg`, `db` and `dx` shaped like `g`, `b`
# and `x`, fills them through `cudnnBNBackward!` (with `momentum` converted to
# `T`), and returns `(dg, db, dx)`.
function ∇batchnorm(g::CuArray{T}, b::CuArray{T}, x::CuArray{T}, dy::CuArray{T},
                    running_mean::CuArray{T}, running_var::CuArray{T}, momentum;
                    cache = nothing, eps = T(1e-5), alpha = T(1),
                    beta = T(0), training = true) where T<:Union{Float32, Float64}
    dg, db, dx = similar(g), similar(b), similar(x)
    cudnnBNBackward!(dg, g, db, dx, x, dy, running_mean, running_var, T(momentum);
                     training = training, cache = cache, eps = eps,
                     alpha = alpha, beta = beta)
    return (dg, db, dx)
end
|
|
||||||
|
|
||||||
# Fill `dg`, `db` and `dx` with the batch-norm gradients for upstream
# gradient `dy`.
#
# With `training = true` the cuDNN backward kernel is used. `mean` / `ivar`
# are taken from `cache` when one is given, otherwise null pointers are
# passed. An `eps` below CUDNN_BN_MIN_EPSILON is raised to that minimum.
#
# With `training = false` the gradients are computed by broadcasting from the
# running statistics, reducing over every dimension except the second-to-last.
function cudnnBNBackward!(dg::CuArray{T}, g::CuArray{T}, db::CuArray{T},
                          dx::CuArray{T}, x::CuArray{T}, dy::CuArray{T},
                          running_mean::CuArray{T}, running_var::CuArray{T},
                          momentum; cache = nothing, eps = T(1e-5),
                          alpha = T(1), beta = T(0),
                          dalpha = T(1), dbeta = T(0), training = true) where T<:Union{Float32, Float64}
    if training
        xd = TensorDesc(x)
        dyd = TensorDesc(dy)
        dxd = TensorDesc(dx)
        gd = TensorDesc(T, _wsize(x))
        if cache !== nothing
            mean, ivar = cache.mean, cache.ivar
            # `info(...)` no longer exists in Julia 1.0; this is diagnostic
            # chatter, so it goes through the logging macros at debug level.
            @debug "mean and ivar are fetched from the cache"
        else
            mean, ivar = CU_NULL, CU_NULL
        end

        if eps < CUDNN.CUDNN_BN_MIN_EPSILON
            eps = CUDNN.CUDNN_BN_MIN_EPSILON
        end

        CUDNN.cudnnBatchNormalizationBackward(
            handle(), CUDNN.CUDNN_BATCHNORM_SPATIAL,
            Ref(T(alpha)), Ref(T(beta)), Ref(T(dalpha)), Ref(T(dbeta)),
            xd, x, dyd, dy, dxd, dx, gd, g, dg, db,
            eps, mean, ivar)
    else
        wdims = _wsize(x)
        rdims = _reddims(dy)
        ivar = 1 ./ sqrt.(reshape(running_var, wdims) .+ eps)
        dx .= dy .* reshape(g, wdims) .* ivar
        # `dropdims` and `sum(...; dims)` replace the removed `squeeze` and
        # positional-dims `sum`. Dropping exactly the reduced dimensions also
        # handles 5D input, where a fixed (1, 2, 4) would not.
        dg .= dropdims(sum(dy .* (x .- reshape(running_mean, wdims)) .* ivar, dims = rdims),
                       dims = rdims)
        db .= dropdims(sum(dy, dims = rdims), dims = rdims)
    end
end
|
|
||||||
|
|
||||||
# Flux Interface
|
|
||||||
|
|
||||||
(BN::Flux.BatchNorm)(x::Union{CuArray{T,2},CuArray{T,4},CuArray{T,5}}, cache = nothing) where T<:Union{Float32, Float64} =
|
(BN::Flux.BatchNorm)(x::Union{CuArray{T,2},CuArray{T,4},CuArray{T,5}}, cache = nothing) where T<:Union{Float32, Float64} =
|
||||||
BN.λ.(batchnorm(BN.γ, BN.β, x, BN.μ, BN.σ², BN.momentum; cache = cache, alpha = 1, beta = 0, eps = BN.ϵ, training = Flux.istraining()))
|
BN.λ.(batchnorm(BN.γ, BN.β, x, BN.μ, BN.σ², BN.momentum; cache = cache, alpha = 1, beta = 0, eps = BN.ϵ, training = Flux.istraining()))
|
||||||
|
@ -1,190 +1,7 @@
|
|||||||
using CuArrays.CUDNN: handle, cudnnDataType, TensorDesc, FilterDesc
|
|
||||||
|
|
||||||
import CuArrays.CUDAdrv: CU_NULL
|
|
||||||
|
|
||||||
using LinearAlgebra
|
|
||||||
|
|
||||||
# RNN cell kinds
const RNN_RELU = 0  # Stock RNN with ReLU activation
const RNN_TANH = 1  # Stock RNN with tanh activation
const LSTM     = 2  # LSTM with no peephole connections
const GRU      = 3  # Using h' = tanh(r * Uh(t-1) + Wx) and h = (1 - z) * h' + z * h(t-1)

# Input modes
const LINEAR_INPUT = 0
const SKIP_INPUT   = 1

# Direction modes
const UNIDIRECTIONAL = 0
const BIDIRECTIONAL  = 1

# RNN algorithms
const RNN_ALGO_STANDARD        = 0
const RNN_ALGO_PERSIST_STATIC  = 1
const RNN_ALGO_PERSIST_DYNAMIC = 2
|
|
||||||
|
|
||||||
# param layout:
# RNN: [weight, bias] × [input, hidden]
# GRU: [weight, bias] × [input, hidden] × [reset, update, newmem]
# LSTM: [weight, bias] × [input, hidden] × [input, forget, newmem, output]

# Carve the flat parameter vector `w` into views, without copying:
# an `input × hidden*n` matrix first, then a `hidden × hidden*n` matrix, then
# a bias vector of length `hidden*n`. Returns `((wx, wh), bias)`.
function params(w::CuVector, input, hidden, n = 1)
    xshape = (input, hidden * n)
    hshape = (hidden, hidden * n)
    wx = reshape(view(w, 1:prod(xshape)), xshape)
    wh = reshape(view(w, length(wx) .+ (1:prod(hshape))), hshape)
    bias = view(w, length(wx) + length(wh) .+ (1:hidden * n))
    return (wx, wh), bias
end
|
|
||||||
|
|
||||||
# An RNN descriptor pointer together with its configuration and parameters.
# `params` is the flat parameter vector; `weights` and `bias` are the pieces
# `RNNDesc{T}(mode, input, hidden)` obtains from it via `params(...)`.
mutable struct RNNDesc{T}
    mode::Int
    input::Int
    hidden::Int
    params::CuVector{T}
    weights::NTuple{2,CuMatrix{T}}
    bias::CuVector{T}
    ptr::Ptr{Nothing}
end

# Lets an `RNNDesc` be passed wherever a raw `Ptr{Nothing}` is expected.
function Base.unsafe_convert(::Type{Ptr{Nothing}}, d::RNNDesc)
    return d.ptr
end
|
|
||||||
|
|
||||||
# Number of elements of type `T` in the parameter buffer for RNN descriptor
# `r` with the given `input` size. cuDNN reports the size in bytes, which is
# divided by `sizeof(T)`.
function rnnParamSize(T, r, input)
    nbytes = Csize_t[0]
    CUDNN.cudnnGetRNNParamsSize(handle(), r, TensorDesc(T, (1,input,1)),
                                nbytes, cudnnDataType(T))
    return Int(nbytes[])÷sizeof(T)
end
|
|
||||||
|
|
||||||
# Gate count per RNN mode, indexed by the zero-based mode constant:
# RNN_RELU => 1, RNN_TANH => 1, LSTM => 4, GRU => 3.
# A tuple is used so the lookup does not allocate a Vector on every call.
ngates(mode) = (1, 1, 4, 3)[mode+1]
|
|
||||||
# Gate count for the mode stored on an RNN descriptor.
function ngates(desc::RNNDesc)
    return ngates(desc.mode)
end
|
|
||||||
|
|
||||||
# Build an RNN descriptor for `mode` with the given `input` / `hidden` sizes.
# It is configured with a zero-probability dropout descriptor, linear input,
# a single direction and the standard algorithm. The parameter vector is
# zero-initialised and split into weight and bias views by `params`. The
# descriptor is destroyed by a finalizer attached to the returned object.
#
# NOTE(review): `dropoutDesc` is not stored on the returned object, so it may
# be finalized independently of it — confirm cuDNN does not need it to outlive
# this call.
function RNNDesc{T}(mode::Int, input::Int, hidden::Int; layers = 1) where T
    slot = [C_NULL]    # receives the descriptor pointer
    CUDNN.cudnnCreateRNNDescriptor(slot)

    dropoutDesc = DropoutDesc(0)
    CUDNN.cudnnSetRNNDescriptor_v6(
        handle(), slot[], hidden, layers, dropoutDesc,
        CUDNN.cudnnRNNInputMode_t(LINEAR_INPUT),
        CUDNN.cudnnDirectionMode_t(UNIDIRECTIONAL),
        CUDNN.cudnnRNNMode_t(mode),
        CUDNN.cudnnRNNAlgo_t(RNN_ALGO_STANDARD),
        cudnnDataType(T))

    w = CuArrays.zeros(T, rnnParamSize(T, slot[], input))
    # TODO: avoid reserve allocation here
    weights, bias = params(w, input, hidden, ngates(mode))
    rd = RNNDesc{T}(mode, input, hidden, w, weights, bias, slot[])
    finalizer(CUDNN.cudnnDestroyRNNDescriptor, rd)
    return rd
end
|
|
||||||
|
|
||||||
# Workspace size, as an `Int`, that cuDNN reports for running descriptor `r`
# over `seqlen` steps with input descriptors `xdesc`.
function rnnWorkspaceSize(r::RNNDesc, seqlen, xdesc)
    nbytes = Csize_t[0]
    CUDNN.cudnnGetRNNWorkspaceSize(handle(), r, seqlen, xdesc, nbytes)
    return Int(nbytes[])
end
|
|
||||||
|
|
||||||
# Module-level scratch buffer, held in a one-element array so that it can be
# swapped for a larger one.
const workspace = [CuVector{UInt8}(undef, 1)]

# Return the shared buffer if it already holds at least `bytes` bytes;
# otherwise replace it with a new buffer of exactly `bytes` and return that.
function getworkspace(bytes)
    current = workspace[]
    if length(current) ≥ bytes
        return current
    end
    grown = CuVector{UInt8}(undef, bytes)
    workspace[] = grown
    return grown
end

# Shared buffer sized for what descriptor `r` needs for the given sequence.
function getworkspace(r::RNNDesc, seqlen, xdesc)
    return getworkspace(rnnWorkspaceSize(r, seqlen, xdesc))
end
|
|
||||||
|
|
||||||
# Training reserve size, as an `Int`, that cuDNN reports for running
# descriptor `r` over `seqlen` steps with input descriptors `xdesc`.
function rnnTrainingReserveSize(r::RNNDesc, seqlen, xdesc)
    nbytes = Csize_t[0]
    CUDNN.cudnnGetRNNTrainingReserveSize(handle(), r, seqlen, xdesc, nbytes)
    return Int(nbytes[])
end
|
|
||||||
|
|
||||||
# Run the cuDNN RNN forward pass. With no `reserve` buffer the inference entry
# point is called; otherwise the training entry point is called and `reserve`
# is passed along with its length.
function cudnnRNNForward(rnn::RNNDesc{T}, seqlen, xd, x, hd, h, cd, c, wd, w, yd, y, hod, ho, cod, co,
                         workspace, reserve=nothing) where T
    # Identity test: `===` is the idiomatic check against `nothing` and never
    # dispatches to a user-defined `==`.
    if reserve === nothing
        CUDNN.cudnnRNNForwardInference(handle(), rnn, seqlen, xd, x, hd, h, cd, c, wd, w, yd, y,
                                       hod, ho, cod, co, workspace, length(workspace))
    else
        CUDNN.cudnnRNNForwardTraining(handle(), rnn, seqlen, xd, x, hd, h, cd, c, wd, w, yd, y,
                                      hod, ho, cod, co, workspace, length(workspace),
                                      reserve, length(reserve))
    end
end
|
|
||||||
|
|
||||||
# One-element vector holding a tensor descriptor of shape
# (1, size(x, 1), size(x, 2)) for `x`.
function xDesc(x)
    return [TensorDesc(eltype(x), (1, size(x, 1), size(x, 2)))]
end

# (descriptor, data) pair for a state array. An absent state maps to a pair of
# null pointers.
function hDesc(h::Nothing)
    return C_NULL, CU_NULL
end

# An integer stands in for an absent state, but only `0` is accepted.
function hDesc(x::Integer)
    @assert x == 0
    return hDesc(nothing)
end

function hDesc(h::CuArray)
    desc = TensorDesc(eltype(h), (size(h, 1), size(h, 2), 1))
    return desc, h
end
|
|
||||||
|
|
||||||
# Give state `h` as many columns as input `x` by broadcasting against a row of
# ones. A vector `x` leaves `h` untouched. A matrix `h` is only widened when
# it has a single column.
# TODO: can we just manipulate strides here?
# TODO: should use repmat, but this isn't implemented.
hBatch(x::AbstractVector, h::CuVector) = h

function hBatch(x::AbstractMatrix, h::CuVector)
    return h .* CuArrays.ones(1, size(x, 2))
end

function hBatch(x::AbstractMatrix, h::CuMatrix)
    ncols = size(h, 2) == 1 ? size(x, 2) : 1
    return h .* CuArrays.ones(1, ncols)
end
|
|
||||||
|
|
||||||
# Run one step (sequence length 1) of `rnn` on input `x` with state `h_` and
# optional cell state `c_`.
#
# Returns `(y, ho)`, or `(y, ho, co)` when a cell state was given. When
# `train` is `Val{true}` a reserve buffer is allocated and the result becomes
# `(reserve, result)`.
#
# Asserts that `x` has `rnn.input` rows, that the batched state has
# `rnn.hidden` rows, and that `x` and the batched state agree in dimension 2.
function forward(rnn::RNNDesc{T}, x::CuArray{T}, h_::CuArray{T}, c_ = nothing, train = Val{false}) where T
    h = hBatch(x, h_)
    # `=== nothing` is an identity test; `==` would go through generic
    # equality dispatch on the array type.
    c = c_ === nothing ? nothing : hBatch(x, c_)
    @assert size(x, 1) == rnn.input
    @assert size(h, 1) == rnn.hidden
    @assert size(x, 2) == size(h, 2)
    seqLength = 1
    xdesc = xDesc(x)
    y = x isa AbstractVector ? similar(x, rnn.hidden) : similar(x, rnn.hidden, size(x, 2))
    ho = similar(h)
    ydesc = xDesc(y)
    workspace = getworkspace(rnn, seqLength, xdesc)
    reserve = train == Val{true} ?
        CuVector{UInt8}(undef, rnnTrainingReserveSize(rnn, seqLength, xdesc)) :
        nothing
    co = c === nothing ? c : similar(c)
    cudnnRNNForward(rnn, seqLength,
                    xdesc, x,
                    hDesc(h)...,
                    hDesc(c)...,
                    FilterDesc(T, (1, 1, length(rnn.params))), rnn.params,
                    ydesc, y,
                    hDesc(ho)...,
                    hDesc(co)...,
                    workspace, reserve)
    result = c === nothing ? (y, ho) : (y, ho, co)
    return train == Val{true} ? (reserve, result) : result
end
|
|
||||||
|
|
||||||
# `forward` with the training flag set, so the result is `(reserve, result)`.
function forwardTrain(rnn::RNNDesc{T}, x::CuArray{T}, h::CuArray{T}, c = nothing) where T
    return forward(rnn, x, h, c, Val{true})
end
|
|
||||||
|
|
||||||
# Backward pass with respect to the data for a single step.
#
# `dy_` may be an integer, in which case a zero array shaped like `y` is used
# as the output gradient. Returns `(dx, dh)`, or `(dx, dh, dc)` when a cell
# state `c` is given.
function backwardData(rnn::RNNDesc{T}, y, dy_, dho, dco, h, c, reserve) where T
    # Same as above, any more efficient way?
    dy = dy_ isa Integer ? zero(y) : dy_
    yd = xDesc(y)
    dx = y isa AbstractVector ? similar(dy, rnn.input) : similar(dy, rnn.input, size(dy, 2))
    dh = similar(h)
    # `=== nothing` is an identity test; `==` would go through generic
    # equality dispatch on the array type.
    dc = c === nothing ? nothing : similar(c)
    workspace = getworkspace(rnn, 1, yd)
    CUDNN.cudnnRNNBackwardData(handle(), rnn, 1,
        yd, y, yd, dy, hDesc(dho)..., hDesc(dco)...,
        FilterDesc(T, (1, 1, length(rnn.params))), rnn.params,
        hDesc(h)..., hDesc(c)..., xDesc(dx), dx, hDesc(dh)..., hDesc(dc)...,
        workspace, length(workspace), reserve, length(reserve))
    return c === nothing ? (dx, dh) : (dx, dh, dc)
end

# Variant without a cell state: no `dco` and no `c`.
backwardData(rnn, y, dy, dho, hx, reserve) =
    backwardData(rnn, y, dy, dho, nothing, hx, nothing, reserve)
|
|
||||||
|
|
||||||
# Backward pass with respect to the parameters for a single step. The
# gradient is accumulated into a zero vector shaped like `rnn.params` and
# returned split the same way `params` splits the parameters:
# `((dWx, dWh), dbias)`.
#
# NOTE(review): this reads the shared `workspace[]` as-is rather than going
# through `getworkspace` — presumably it relies on an earlier call having
# sized the buffer; confirm.
function backwardWeights(rnn::RNNDesc{T}, x, h, y, reserve) where T
    grad = zero(rnn.params)
    CUDNN.cudnnRNNBackwardWeights(
        handle(), rnn, 1,
        xDesc(x), x, hDesc(h)..., xDesc(y), y,
        workspace[], length(workspace[]),
        FilterDesc(T, (1, 1, length(grad))), grad,
        reserve, length(reserve))
    return params(grad, rnn.input, rnn.hidden, ngates(rnn))
end
|
|
||||||
|
|
||||||
# Interface
|
|
||||||
|
|
||||||
import ..Flux: Flux, relu
|
import ..Flux: Flux, relu
|
||||||
using CuArrays.CUDAnative
|
using CuArrays.CUDAnative
|
||||||
using CuArrays: @cuindex, cudims
|
using CuArrays: @cuindex, cudims
|
||||||
|
using LinearAlgebra
|
||||||
|
|
||||||
function LinearAlgebra.copy_transpose!(dst::CuArray, src::CuArray)
|
function LinearAlgebra.copy_transpose!(dst::CuArray, src::CuArray)
|
||||||
function kernel(dst, src)
|
function kernel(dst, src)
|
||||||
@ -202,7 +19,7 @@ CuGRU{T} = Flux.GRUCell{<:CuArray{T,2},<:CuArray{T,1}}
|
|||||||
CuLSTM{T} = Flux.LSTMCell{<:CuArray{T,2},<:CuArray{T,1}}
|
CuLSTM{T} = Flux.LSTMCell{<:CuArray{T,2},<:CuArray{T,1}}
|
||||||
CuRNNs{T} = Union{CuRNN{T},CuGRU{T},CuLSTM{T}}
|
CuRNNs{T} = Union{CuRNN{T},CuGRU{T},CuLSTM{T}}
|
||||||
|
|
||||||
function copyparams!(m::CuRNNs, d::RNNDesc)
|
function copyparams!(m::CuRNNs, d::CUDNN.RNNDesc)
|
||||||
Wi, Wh = d.weights
|
Wi, Wh = d.weights
|
||||||
copy_transpose!(Wi, m.Wi)
|
copy_transpose!(Wi, m.Wi)
|
||||||
copy_transpose!(Wh, m.Wh)
|
copy_transpose!(Wh, m.Wh)
|
||||||
@ -210,19 +27,19 @@ function copyparams!(m::CuRNNs, d::RNNDesc)
|
|||||||
return
|
return
|
||||||
end
|
end
|
||||||
|
|
||||||
function RNNDesc(m::CuRNNs{T}) where T
|
function CUDNN.RNNDesc(m::CuRNNs{T}) where T
|
||||||
h, i = length(m.h), size(m.Wi, 2)
|
h, i = length(m.h), size(m.Wi, 2)
|
||||||
mode = m isa CuRNN ?
|
mode = m isa CuRNN ?
|
||||||
(m.σ == tanh ? RNN_TANH : RNN_RELU) :
|
(m.σ == tanh ? CUDNN.CUDNN_RNN_TANH : CUDNN.CUDNN_RNN_RELU) :
|
||||||
m isa CuGRU ? GRU : LSTM
|
m isa CuGRU ? CUDNN.CUDNN_GRU : CUDNN.CUDNN_LSTM
|
||||||
r = RNNDesc{T}(mode, i, h)
|
r = CUDNN.RNNDesc{T}(mode, i, h)
|
||||||
return r
|
return r
|
||||||
end
|
end
|
||||||
|
|
||||||
const descs = WeakKeyDict()
|
const descs = WeakKeyDict()
|
||||||
|
|
||||||
function desc(rnn)
|
function desc(rnn)
|
||||||
d = haskey(descs, rnn) ? descs[rnn] : (descs[rnn] = RNNDesc(rnn))
|
d = haskey(descs, rnn) ? descs[rnn] : (descs[rnn] = CUDNN.RNNDesc(rnn))
|
||||||
copyparams!(rnn, d)
|
copyparams!(rnn, d)
|
||||||
return d
|
return d
|
||||||
end
|
end
|
||||||
@ -230,17 +47,17 @@ end
|
|||||||
using ..Flux: @adjoint
|
using ..Flux: @adjoint
|
||||||
|
|
||||||
function (m::CuRNN{T})(h::CuArray{T}, x::CuArray{T}) where T <: Union{Float32,Float64}
|
function (m::CuRNN{T})(h::CuArray{T}, x::CuArray{T}) where T <: Union{Float32,Float64}
|
||||||
y, h′ = forward(desc(m), x, h)
|
y, h′ = CUDNN.forward(desc(m), x, h)
|
||||||
return h′, y
|
return h′, y
|
||||||
end
|
end
|
||||||
|
|
||||||
function (m::CuGRU{T})(h::CuArray{T}, x::CuArray{T}) where T <: Union{Float32,Float64}
|
function (m::CuGRU{T})(h::CuArray{T}, x::CuArray{T}) where T <: Union{Float32,Float64}
|
||||||
y, h′ = forward(desc(m), x, h)
|
y, h′ = CUDNN.forward(desc(m), x, h)
|
||||||
return h′, y
|
return h′, y
|
||||||
end
|
end
|
||||||
|
|
||||||
function (m::CuLSTM{T})(h::NTuple{2,CuArray{T}}, x::CuArray{T}) where T <: Union{Float32,Float64}
|
function (m::CuLSTM{T})(h::NTuple{2,CuArray{T}}, x::CuArray{T}) where T <: Union{Float32,Float64}
|
||||||
y, h′, c′ = forward(desc(m), x, h[1], h[2])
|
y, h′, c′ = CUDNN.forward(desc(m), x, h[1], h[2])
|
||||||
return (h′, c′), y
|
return (h′, c′), y
|
||||||
end
|
end
|
||||||
|
|
||||||
@ -257,12 +74,12 @@ unbroadcast(x::AbstractArray, Δ) =
|
|||||||
|
|
||||||
for RNN in (CuRNN, CuGRU)
|
for RNN in (CuRNN, CuGRU)
|
||||||
@eval @adjoint function (m::$RNN{T})(h::CuArray{T}, x::CuArray{T}) where T <: Union{Float32,Float64}
|
@eval @adjoint function (m::$RNN{T})(h::CuArray{T}, x::CuArray{T}) where T <: Union{Float32,Float64}
|
||||||
reserve, (y, ho) = forwardTrain(desc(m), x, h)
|
reserve, (y, ho) = CUDNN.forwardTrain(desc(m), x, h)
|
||||||
(ho, y), function (Δ)
|
(ho, y), function (Δ)
|
||||||
dho, dy = Δ
|
dho, dy = Δ
|
||||||
h_ = hBatch(x, h)
|
h_ = CUDNN.hBatch(x, h)
|
||||||
dx, dh = backwardData(descs[m], y, dy, dho, h_, reserve)
|
dx, dh = CUDNN.backwardData(descs[m], y, dy, dho, h_, reserve)
|
||||||
(dWi, dWh), db = backwardWeights(descs[m], x, h_, y, reserve)
|
(dWi, dWh), db = CUDNN.backwardWeights(descs[m], x, h_, y, reserve)
|
||||||
dm = Ref{Any}((σ=nothing,Wi=transpose(dWi),Wh=transpose(dWh),b=db,h=nothing))
|
dm = Ref{Any}((σ=nothing,Wi=transpose(dWi),Wh=transpose(dWh),b=db,h=nothing))
|
||||||
(dm, unbroadcast(h, dh), dx)
|
(dm, unbroadcast(h, dh), dx)
|
||||||
end
|
end
|
||||||
@ -270,14 +87,14 @@ for RNN in (CuRNN, CuGRU)
|
|||||||
end
|
end
|
||||||
|
|
||||||
@adjoint function (m::CuLSTM)((h, c)::Tuple{CuArray{T},CuArray{T}}, x::CuArray{T}) where T <: Union{Float32,Float64}
|
@adjoint function (m::CuLSTM)((h, c)::Tuple{CuArray{T},CuArray{T}}, x::CuArray{T}) where T <: Union{Float32,Float64}
|
||||||
reserve, (y, ho, co) = forwardTrain(desc(m), x, h, c)
|
reserve, (y, ho, co) = CUDNN.forwardTrain(desc(m), x, h, c)
|
||||||
((ho, co), y), function (Δ)
|
((ho, co), y), function (Δ)
|
||||||
dhc, dy = Δ
|
dhc, dy = Δ
|
||||||
dho, dco = dhc === nothing ? (nothing, nothing) : dhc
|
dho, dco = dhc === nothing ? (nothing, nothing) : dhc
|
||||||
h_ = hBatch(x, h)
|
h_ = CUDNN.hBatch(x, h)
|
||||||
c_ = hBatch(x, c)
|
c_ = CUDNN.hBatch(x, c)
|
||||||
dx, dh, dc = backwardData(descs[m], y, dy, dho, dco, h_, c_, reserve)
|
dx, dh, dc = CUDNN.backwardData(descs[m], y, dy, dho, dco, h_, c_, reserve)
|
||||||
(dWi, dWh), db = backwardWeights(descs[m], x, h_, y, reserve)
|
(dWi, dWh), db = CUDNN.backwardWeights(descs[m], x, h_, y, reserve)
|
||||||
dm = Ref{Any}((Wi=transpose(dWi),Wh=transpose(dWh),b=db,h=nothing,c=nothing))
|
dm = Ref{Any}((Wi=transpose(dWi),Wh=transpose(dWh),b=db,h=nothing,c=nothing))
|
||||||
(dm, (unbroadcast(h, dh), unbroadcast(c, dc)), dx)
|
(dm, (unbroadcast(h, dh), unbroadcast(c, dc)), dx)
|
||||||
end
|
end
|
||||||
|
Loading…
Reference in New Issue
Block a user