# Flux.jl/src/cuda/curnn.jl

using CuArrays.CUDNN: handle, cudnnDataType, TensorDesc, FilterDesc
import CuArrays.CUDAdrv: CU_NULL
using LinearAlgebra
const RNN_RELU = 0 # Stock RNN with ReLu activation
const RNN_TANH = 1 # Stock RNN with tanh activation
const LSTM = 2 # LSTM with no peephole connections
const GRU = 3 # Using h' = tanh(r * Uh(t-1) + Wx) and h = (1 - z) * h' + z * h(t-1)
const LINEAR_INPUT = 0
const SKIP_INPUT = 1
const UNIDIRECTIONAL = 0
const BIDIRECTIONAL = 1
const RNN_ALGO_STANDARD = 0
const RNN_ALGO_PERSIST_STATIC = 1
const RNN_ALGO_PERSIST_DYNAMIC = 2
# param layout:
# RNN: [weight, bias] × [input, hidden]
# GRU: [weight, bias] × [input, hidden] × [reset, update, newmem]
# LSTM: [weight, bias] × [input, hidden] × [input, forget, newmem, output]
function params(w::CuVector, input, hidden, n = 1)
slice(offset, shape) = reshape(view(w, offset.+(1:prod(shape))), shape)
wx = slice(0, (input, hidden*n))
wh = slice(length(wx), (hidden, hidden*n))
bias = view(w, length(wx)+length(wh) .+ (1:hidden*n))
(wx, wh), bias
end
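# A hypothetical example (values assumed, not from the original file): for an
# LSTM (n = 4 gates) with input = 10 and hidden = 20,
#   (wx, wh), b = params(w, 10, 20, 4)
# yields size(wx) == (10, 80), size(wh) == (20, 80) and length(b) == 80,
# all of them views into `w` rather than copies.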
mutable struct RNNDesc{T}
mode::Int
input::Int
hidden::Int
params::CuVector{T}
weights::NTuple{2,CuMatrix{T}}
bias::CuVector{T}
ptr::Ptr{Nothing}
end
Base.unsafe_convert(::Type{Ptr{Nothing}}, d::RNNDesc) = d.ptr
function rnnParamSize(T, r, input)
size = Csize_t[0]
CUDNN.cudnnGetRNNParamsSize(handle(), r, TensorDesc(T, (1,input,1)), size, cudnnDataType(T))
return Int(size[])÷sizeof(T)
end
ngates(mode) = [1, 1, 4, 3][mode+1]
ngates(r::RNNDesc) = ngates(r.mode)
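# ngates maps a cuDNN mode to its gate count: the plain RNN modes have a
# single gate, LSTM has 4 and GRU has 3, matching the param layout above.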
function RNNDesc{T}(mode::Int, input::Int, hidden::Int; layers = 1) where T
d = [C_NULL]
CUDNN.cudnnCreateRNNDescriptor(d)
dropoutDesc = DropoutDesc(0)
inputMode = LINEAR_INPUT
direction = UNIDIRECTIONAL
algo = RNN_ALGO_STANDARD
CUDNN.cudnnSetRNNDescriptor_v6(handle(),d[],hidden,layers,dropoutDesc,CUDNN.cudnnRNNInputMode_t(inputMode),CUDNN.cudnnDirectionMode_t(direction),CUDNN.cudnnRNNMode_t(mode),CUDNN.cudnnRNNAlgo_t(algo),cudnnDataType(T))
w = CuArrays.zeros(T, rnnParamSize(T, d[], input))
# TODO: avoid reserve allocation here
rd = RNNDesc{T}(mode, input, hidden, w, params(w, input, hidden, ngates(mode))..., d[])
finalizer(rd) do x
CUDNN.cudnnDestroyRNNDescriptor(x)
end
return rd
end
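# A minimal usage sketch (assumed, not part of the original file):
#   rd = RNNDesc{Float32}(LSTM, 10, 20)  # 10 input features, 20 hidden units
#   rd.params                            # flat cuDNN parameter vector
#   rd.weights, rd.bias                  # (Wi, Wh) and bias views into rd.params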
function rnnWorkspaceSize(r::RNNDesc, seqlen, xdesc)
size = Csize_t[0]
CUDNN.cudnnGetRNNWorkspaceSize(handle(), r, seqlen, xdesc, size)
return Int(size[])
end
const workspace = [CuVector{UInt8}(undef, 1)]
getworkspace(bytes) =
  length(workspace[]) ≥ bytes ?
workspace[] :
(workspace[] = CuVector{UInt8}(undef, bytes))
getworkspace(r::RNNDesc, seqlen, xdesc) =
getworkspace(rnnWorkspaceSize(r, seqlen, xdesc))
function rnnTrainingReserveSize(r::RNNDesc, seqlen, xdesc)
size = Csize_t[0]
CUDNN.cudnnGetRNNTrainingReserveSize(handle(), r, seqlen, xdesc, size)
return Int(size[])
end
function cudnnRNNForward(rnn::RNNDesc{T}, seqlen, xd, x, hd, h, cd, c, wd, w, yd, y, hod, ho, cod, co,
workspace, reserve=nothing) where T
if reserve == nothing
CUDNN.cudnnRNNForwardInference(handle(), rnn, seqlen, xd, x, hd, h, cd, c, wd, w, yd, y,
hod, ho, cod, co, workspace, length(workspace))
else
CUDNN.cudnnRNNForwardTraining(handle(), rnn, seqlen, xd, x, hd, h, cd, c, wd, w, yd, y,
hod, ho, cod, co, workspace, length(workspace), reserve, length(reserve))
end
end
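# cudnnRNNForward above dispatches on `reserve`: without it the cheaper
# inference path is used; with it the training path records the state that
# the backwardData/backwardWeights calls below need.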
xDesc(x) = [TensorDesc(eltype(x), (1, size(x, 1), size(x, 2)))]
hDesc(h::Nothing) = C_NULL, CU_NULL
hDesc(x::Integer) = (@assert x == 0; hDesc(nothing))
function hDesc(h::CuArray)
TensorDesc(eltype(h), (size(h, 1), size(h, 2), 1)), h
end
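# hDesc returns a (descriptor, pointer) pair that is splatted into the cuDNN
# calls; `nothing` (or 0) maps to C_NULL/CU_NULL so the hidden or cell state
# argument is omitted on the C side.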
# TODO: can we just manipulate strides here?
# TODO: should use repmat, but this isn't implemented.
hBatch(x::AbstractVector, h::CuVector) = h
hBatch(x::AbstractMatrix, h::CuVector) = h .* CuArrays.ones(1, size(x, 2))
hBatch(x::AbstractMatrix, h::CuMatrix) = h .* CuArrays.ones(1, size(h, 2) == 1 ? size(x, 2) : 1)
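# hBatch broadcasts a single hidden-state column across the batch: for x of
# size (input, batch) and h of length hidden, hBatch(x, h) is a
# (hidden, batch) matrix with h repeated in every column.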
function forward(rnn::RNNDesc{T}, x::CuArray{T}, h_::CuArray{T}, c_ = nothing, train = Val{false}) where T
h = hBatch(x, h_)
c = c_ == nothing ? nothing : hBatch(x, c_)
@assert size(x, 1) == rnn.input
@assert size(h, 1) == rnn.hidden
@assert size(x, 2) == size(h, 2)
seqLength = 1
xdesc = xDesc(x)
y = x isa AbstractVector ? similar(x, rnn.hidden) : similar(x, rnn.hidden, size(x, 2))
ho = similar(h)
ydesc = xDesc(y)
workspace = getworkspace(rnn, seqLength, xdesc)
reserve = train == Val{true} ?
CuVector{UInt8}(undef, rnnTrainingReserveSize(rnn, seqLength, xdesc)) :
nothing
co = c == nothing ? c : similar(c)
cudnnRNNForward(rnn, seqLength,
xdesc, x,
hDesc(h)...,
hDesc(c)...,
FilterDesc(T, (1, 1, length(rnn.params))), rnn.params,
ydesc, y,
hDesc(ho)...,
hDesc(co)...,
workspace, reserve)
result = c == nothing ? (y, ho) : (y, ho, co)
return train == Val{true} ? (reserve, result) : result
end
forwardTrain(rnn::RNNDesc{T}, x::CuArray{T}, h::CuArray{T}, c = nothing) where T =
forward(rnn, x, h, c, Val{true})
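# A hypothetical call sketch (not from the original file), for rd::RNNDesc
# with x of size (input, batch) and h, c of size (hidden, batch):
#   y, h′ = forward(rd, x, h)                         # RNN/GRU inference
#   y, h′, c′ = forward(rd, x, h, c)                  # LSTM inference
#   reserve, (y, h′, c′) = forwardTrain(rd, x, h, c)  # keeps reserve for backward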
function backwardData(rnn::RNNDesc{T}, y, dy_, dho, dco, h, c, reserve) where T
# Same as above; any more efficient way?
dy = dy_ isa Integer ? zero(y) : dy_
yd = xDesc(y)
dx = y isa AbstractVector ? similar(dy, rnn.input) : similar(dy, rnn.input, size(dy, 2))
dh = similar(h)
dc = c == nothing ? nothing : similar(c)
workspace = getworkspace(rnn, 1, yd)
CUDNN.cudnnRNNBackwardData(handle(), rnn, 1,
yd, y, yd, dy, hDesc(dho)..., hDesc(dco)...,
FilterDesc(T, (1, 1, length(rnn.params))), rnn.params,
hDesc(h)..., hDesc(c)..., xDesc(dx), dx, hDesc(dh)..., hDesc(dc)...,
workspace, length(workspace), reserve, length(reserve))
return c == nothing ? (dx, dh) : (dx, dh, dc)
end
backwardData(rnn, y, dy, dho, hx, reserve) =
backwardData(rnn, y, dy, dho, nothing, hx, nothing, reserve)
function backwardWeights(rnn::RNNDesc{T}, x, h, y, reserve) where T
dw = zero(rnn.params)
CUDNN.cudnnRNNBackwardWeights(handle(), rnn, 1,
xDesc(x), x, hDesc(h)..., xDesc(y), y,
workspace[], length(workspace[]),
FilterDesc(T, (1, 1, length(dw))), dw,
reserve, length(reserve))
return params(dw, rnn.input, rnn.hidden, ngates(rnn))
end
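# backwardData consumes the reserve buffer recorded by forwardTrain and
# returns the input/state gradients; backwardWeights reuses the same reserve
# and returns ((dWi, dWh), db) in the layout produced by params.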
# Interface
import ..Flux: Flux, relu
using CuArrays.CUDAnative
using CuArrays: @cuindex, cudims
function LinearAlgebra.copy_transpose!(dst::CuArray, src::CuArray)
function kernel(dst, src)
I = @cuindex dst
dst[I...] = src[reverse(I)...]
return
end
blk, thr = cudims(dst)
@cuda blocks=blk threads=thr kernel(dst, src)
return dst
end
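# cuDNN's flat parameter vector stores each weight matrix transposed relative
# to Flux's Wi/Wh (compare the shapes in params above), so copyparams! below
# copies with indices reversed via this kernel.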
CuRNN{T} = Flux.RNNCell{<:Union{typeof(tanh),typeof(relu)},<:CuArray{T,2},<:CuArray{T,1}}
CuGRU{T} = Flux.GRUCell{<:CuArray{T,2},<:CuArray{T,1}}
CuLSTM{T} = Flux.LSTMCell{<:CuArray{T,2},<:CuArray{T,1}}
CuRNNs{T} = Union{CuRNN{T},CuGRU{T},CuLSTM{T}}
function copyparams!(m::CuRNNs, d::RNNDesc)
Wi, Wh = d.weights
copy_transpose!(Wi, m.Wi)
copy_transpose!(Wh, m.Wh)
copy_transpose!(d.bias, m.b)
return
end
function RNNDesc(m::CuRNNs{T}) where T
h, i = length(m.h), size(m.Wi, 2)
mode = m isa CuRNN ?
(m.σ == tanh ? RNN_TANH : RNN_RELU) :
m isa CuGRU ? GRU : LSTM
r = RNNDesc{T}(mode, i, h)
return r
end
const descs = WeakKeyDict()
function desc(rnn)
d = haskey(descs, rnn) ? descs[rnn] : (descs[rnn] = RNNDesc(rnn))
copyparams!(rnn, d)
return d
end
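# descs caches one RNNDesc per layer; the WeakKeyDict keys on the layer itself
# so a descriptor can be collected along with its model, while copyparams!
# refreshes the cuDNN weights from the Flux layer on every lookup.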
using ..Flux: @adjoint
function (m::CuRNN{T})(h::CuArray{T}, x::CuArray{T}) where T <: Union{Float32,Float64}
y, h = forward(desc(m), x, h)
return h, y
end
function (m::CuGRU{T})(h::CuArray{T}, x::CuArray{T}) where T <: Union{Float32,Float64}
y, h = forward(desc(m), x, h)
return h, y
end
function (m::CuLSTM{T})(h::NTuple{2,CuArray{T}}, x::CuArray{T}) where T <: Union{Float32,Float64}
y, h, c = forward(desc(m), x, h[1], h[2])
return (h, c), y
end
(m::CuRNN{T})(h::CuArray{T}, x) where T <: Union{Float32,Float64} = m(h, CuArray{T}(x))
(m::CuGRU{T})(h::CuArray{T}, x) where T <: Union{Float32,Float64} = m(h, CuArray{T}(x))
(m::CuLSTM{T})(h::NTuple{2,CuArray{T}}, x) where T <: Union{Float32,Float64} = m(h, CuArray{T}(x))
trim(x, Δ) = reshape(Δ, ntuple(i -> size(Δ, i), Val(ndims(x))))
unbroadcast(x::AbstractArray, Δ) =
size(x) == size(Δ) ? Δ :
length(x) == length(Δ) ? trim(x, Δ) :
trim(x, sum(Δ, dims = ntuple(i -> size(x, i) == 1 ? i : ndims(Δ)+1, Val(ndims(Δ)))))
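# unbroadcast reverses a broadcast when accumulating gradients: e.g. if h has
# size (20,) and was broadcast against a (20, 8) batch, a (20, 8) gradient Δ
# is summed over the batch dimension and trimmed back to size (20,).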
for RNN in (CuRNN, CuGRU)
@eval @adjoint function (m::$RNN{T})(h::CuArray{T}, x::CuArray{T}) where T <: Union{Float32,Float64}
reserve, (y, ho) = forwardTrain(desc(m), x, h)
(ho, y), function (Δ)
dho, dy = Δ
h_ = hBatch(x, h)
dx, dh = backwardData(descs[m], y, dy, dho, h_, reserve)
(dWi, dWh), db = backwardWeights(descs[m], x, h_, y, reserve)
dm = Ref{Any}((σ=nothing,Wi=transpose(dWi),Wh=transpose(dWh),b=db,h=nothing))
(dm, unbroadcast(h, dh), dx)
end
end
end
@adjoint function (m::CuLSTM)((h, c)::Tuple{CuArray{T},CuArray{T}}, x::CuArray{T}) where T <: Union{Float32,Float64}
reserve, (y, ho, co) = forwardTrain(desc(m), x, h, c)
((ho, co), y), function (Δ)
dhc, dy = Δ
dho, dco = dhc === nothing ? (nothing, nothing) : dhc
h_ = hBatch(x, h)
c_ = hBatch(x, c)
dx, dh, dc = backwardData(descs[m], y, dy, dho, dco, h_, c_, reserve)
(dWi, dWh), db = backwardWeights(descs[m], x, h_, y, reserve)
dm = Ref{Any}((Wi=transpose(dWi),Wh=transpose(dWh),b=db,h=nothing,c=nothing))
(dm, (unbroadcast(h, dh), unbroadcast(c, dc)), dx)
end
end