wip
This commit is contained in:
parent
baff20514d
commit
405cab895e
|
@ -22,351 +22,39 @@ function DropoutDesc(ρ::Real; seed::Integer=0)
|
|||
return desc
|
||||
end
|
||||
|
||||
const RNN_RELU = 0 # Stock RNN with ReLu activation
|
||||
const RNN_TANH = 1 # Stock RNN with tanh activation
|
||||
const LSTM = 2 # LSTM with no peephole connections
|
||||
const GRU = 3 # Using h' = tanh(r * Uh(t-1) + Wx) and h = (1 - z) * h' + z * h(t-1)
|
||||
const BATCHNORM_SPATIAL = 1
|
||||
const BATCHNORM_ACTIVATION = 0
|
||||
const BATCHNORM_MIN_EPS = 1e-5
|
||||
|
||||
const LINEAR_INPUT = 0
|
||||
const SKIP_INPUT = 1
|
||||
bnshape(x::NTuple{4}) = x
|
||||
bnshape(x::Union{NTuple{1},NTuple{2},NTuple{3}}) = bnshape((1,x...))
|
||||
bnshape(x::AbstractArray) = bnshape(size(x))
|
||||
|
||||
const UNIDIRECTIONAL = 0
|
||||
const BIDIRECTIONAL = 1
|
||||
|
||||
const RNN_ALGO_STANDARD = 0
|
||||
const RNN_ALGO_PERSIST_STATIC = 1
|
||||
const RNN_ALGO_PERSIST_DYNAMIC = 2
|
||||
|
||||
# param layout:
|
||||
# RNN: [weight, bias] × [input, hidden]
|
||||
# GRU: [weight, bias] × [input, hidden] × [reset, update, newmem]
|
||||
# LSTM: [weight, bias] × [input, hidden] × [input, forget, newmem, output]
|
||||
|
||||
function params(w::CuVector, input, hidden, n = 1)
|
||||
slice(offset, shape) = reshape(w[offset+(1:prod(shape))], shape)
|
||||
wx = slice(0, (input, hidden*n))
|
||||
wh = slice(length(wx), (hidden, hidden*n))
|
||||
bias = w[length(wx)+length(wh) + (1:hidden*n)]
|
||||
(wx, wh), bias
|
||||
function batchnorm(x::CuArray{T}) where T<:Union{Float32,Float64}
|
||||
y = similar(x)
|
||||
sh = bnshape(x)
|
||||
td_x = TensorDesc(T, sh)
|
||||
td_p = TensorDesc(T, (1,1,sh[3],1))
|
||||
# @check ccall((:cudnnBatchNormalizationForwardTraining,libcudnn),cudnnStatus_t,
|
||||
# (Ptr{Void}, UInt32,
|
||||
# Ptr{T}, Ptr{T}, #alpha and beta
|
||||
# Ptr{Void}, Ptr{T}, #xdesc and x
|
||||
# Ptr{Void}, Ptr{T}, #ydesc and y
|
||||
# Ptr{Void}, Ptr{T}, Ptr{T}, #desc, weight and bias
|
||||
# Cdouble, Ptr{T}, Ptr{T}, #Decay factor, Running mean and Running var
|
||||
# Cdouble, # eps
|
||||
# Ptr{T}, Ptr{T}), #Cached mean and ivar
|
||||
# libcudnn_handle[], BATCHNORM_SPATIAL,
|
||||
# Ref(T(1)), Ref(T(0)),
|
||||
# TensorDesc(x), x, #x
|
||||
# TensorDesc(y), y, #y
|
||||
# TensorDesc(g), g, b, #params
|
||||
# momentum, running_mean, running_var,
|
||||
# eps, mean, ivar)
|
||||
end
|
||||
|
||||
mutable struct RNNDesc{T}
|
||||
mode::Int
|
||||
input::Int
|
||||
hidden::Int
|
||||
params::CuVector{T}
|
||||
weights::NTuple{2,CuMatrix{T}}
|
||||
bias::CuVector{T}
|
||||
ptr::Ptr{Void}
|
||||
end
|
||||
batchnorm(cu(randn(10,5)))
|
||||
|
||||
Base.unsafe_convert(::Type{Ptr{Void}}, d::RNNDesc) = d.ptr
|
||||
TensorDesc(Float32, (1,5,1,1))
|
||||
|
||||
function rnnParamSize(T, r, input)
|
||||
size = Csize_t[0]
|
||||
@check ccall((:cudnnGetRNNParamsSize, libcudnn), cudnnStatus_t, (Ptr{Void},Ptr{Void},Ptr{Void},Ptr{Csize_t},Cint),
|
||||
libcudnn_handle[], r, TensorDesc(T, (1,input,1)), size, cudnnDataType(T))
|
||||
return Int(size[])÷sizeof(T)
|
||||
end
|
||||
|
||||
ngates(mode) = [1, 1, 4, 3][mode+1]
|
||||
ngates(r::RNNDesc) = ngates(r.mode)
|
||||
|
||||
function RNNDesc{T}(mode::Int, input::Int, hidden::Int; layers = 1) where T
|
||||
d = [C_NULL]
|
||||
@check ccall((:cudnnCreateRNNDescriptor,libcudnn),cudnnStatus_t,(Ptr{Ptr{Void}},),d)
|
||||
|
||||
dropoutDesc = DropoutDesc(0)
|
||||
inputMode = LINEAR_INPUT
|
||||
direction = UNIDIRECTIONAL
|
||||
algo = RNN_ALGO_STANDARD
|
||||
@check ccall((:cudnnSetRNNDescriptor_v6,libcudnn), cudnnStatus_t, (Ptr{Void},Ptr{Void},Cint,Cint,Ptr{Void},Cint,Cint,Cint,Cint,Cint),
|
||||
libcudnn_handle[],d[],hidden,layers,dropoutDesc,inputMode,direction,mode,algo,cudnnDataType(T))
|
||||
|
||||
w = cuzeros(T, rnnParamSize(T, d[], input))
|
||||
# TODO: avoid reserve allocation here
|
||||
rd = RNNDesc{T}(mode, input, hidden, w, params(w, input, hidden, ngates(mode))..., d[])
|
||||
finalizer(rd, x ->
|
||||
@check ccall((:cudnnDestroyRNNDescriptor,libcudnn),cudnnStatus_t,(Ptr{Void},),x))
|
||||
return rd
|
||||
end
|
||||
|
||||
function rnnWorkspaceSize(r::RNNDesc, seqlen, xdesc)
|
||||
size = Csize_t[0]
|
||||
@check ccall((:cudnnGetRNNWorkspaceSize, libcudnn), cudnnStatus_t, (Ptr{Void},Ptr{Void},Cint,Ptr{Ptr{Void}},Ptr{Csize_t}),
|
||||
libcudnn_handle[], r, seqlen, xdesc, size)
|
||||
return Int(size[])
|
||||
end
|
||||
|
||||
const workspace = [CuVector{UInt8}(1)]
|
||||
|
||||
getworkspace(bytes) =
|
||||
length(workspace[]) ≥ bytes ?
|
||||
workspace[] :
|
||||
(workspace[] = CuVector{UInt8}(bytes))
|
||||
|
||||
getworkspace(r::RNNDesc, seqlen, xdesc) =
|
||||
getworkspace(rnnWorkspaceSize(r, seqlen, xdesc))
|
||||
|
||||
function rnnTrainingReserveSize(r::RNNDesc, seqlen, xdesc)
|
||||
size = Csize_t[0]
|
||||
@check ccall((:cudnnGetRNNTrainingReserveSize,libcudnn), cudnnStatus_t, (Ptr{Void}, Ptr{Void}, Cint, Ptr{Ptr{Void}}, Ptr{Csize_t}),
|
||||
libcudnn_handle[], r, seqlen, xdesc, size)
|
||||
return Int(size[])
|
||||
end
|
||||
|
||||
function cudnnRNNForward(rnn::RNNDesc{T}, seqlen, xd, x, hd, h, cd, c, wd, w, yd, y, hod, ho, cod, co,
|
||||
workspace, reserve=nothing) where T
|
||||
if reserve == nothing
|
||||
@check ccall((:cudnnRNNForwardInference, libcudnn), cudnnStatus_t,
|
||||
(Ptr{Void}, Ptr{Void}, Cint,
|
||||
Ptr{Ptr{Void}}, Ptr{T}, Ptr{Void}, Ptr{T}, Ptr{Void}, Ptr{T},
|
||||
Ptr{Void}, Ptr{T}, Ptr{Ptr{Void}}, Ptr{T}, Ptr{Void}, Ptr{T},
|
||||
Ptr{Void}, Ptr{T},
|
||||
Ptr{Void}, Csize_t),
|
||||
libcudnn_handle[], rnn, seqlen,
|
||||
xd, x, hd, h, cd, c, wd, w, yd, y, hod, ho, cod, co,
|
||||
workspace, length(workspace))
|
||||
else
|
||||
@check ccall((:cudnnRNNForwardTraining, libcudnn), cudnnStatus_t,
|
||||
(Ptr{Void}, Ptr{Void}, Cint,
|
||||
Ptr{Ptr{Void}}, Ptr{T}, Ptr{Void}, Ptr{T}, Ptr{Void}, Ptr{T}, Ptr{Void}, Ptr{T}, Ptr{Ptr{Void}}, Ptr{T}, Ptr{Void}, Ptr{T}, Ptr{Void}, Ptr{T},
|
||||
Ptr{Void}, Csize_t, Ptr{Void}, Csize_t),
|
||||
libcudnn_handle[], rnn, seqlen,
|
||||
xd, x, hd, h, cd, c, wd, w, yd, y, hod, ho, cod, co,
|
||||
workspace, length(workspace), reserve, length(reserve))
|
||||
end
|
||||
end
|
||||
|
||||
xDesc(x) = [TensorDesc(eltype(x), (1, size(x, 1), size(x, 2)))]
|
||||
|
||||
hDesc(h::Void) = C_NULL, C_NULL
|
||||
hDesc(x::Integer) = (@assert x == 0; hDesc(nothing))
|
||||
function hDesc(h::CuArray)
|
||||
TensorDesc(eltype(h), (size(h, 1), size(h, 2), 1)), h
|
||||
end
|
||||
|
||||
# TODO: can we just manipulate strides here?
|
||||
# TODO: should use repmat, but this isn't implemented.
|
||||
hBatch(x::AbstractVector, h::CuVector) = h
|
||||
hBatch(x::AbstractMatrix, h::CuVector) = h .* cuones(1, size(x, 2))
|
||||
hBatch(x::AbstractMatrix, h::CuMatrix) = h .* cuones(1, size(h,2) == 1 ? size(x,2) : 1)
|
||||
|
||||
function forward(rnn::RNNDesc{T}, x::CuArray{T}, h_::CuArray{T}, c_ = nothing, train = Val{false}) where T
|
||||
h = hBatch(x, h_)
|
||||
c = c_ == nothing ? nothing : hBatch(x, c_)
|
||||
@assert size(x, 1) == rnn.input
|
||||
@assert size(h, 1) == rnn.hidden
|
||||
@assert size(x, 2) == size(h, 2)
|
||||
seqLength = 1
|
||||
xdesc = xDesc(x)
|
||||
y = x isa AbstractVector ? similar(x, rnn.hidden) : similar(x, rnn.hidden, size(x, 2))
|
||||
ho = similar(h)
|
||||
ydesc = xDesc(y)
|
||||
workspace = getworkspace(rnn, seqLength, xdesc)
|
||||
reserve = train == Val{true} ?
|
||||
CuVector{UInt8}(rnnTrainingReserveSize(rnn, seqLength, xdesc)) :
|
||||
nothing
|
||||
co = c == nothing ? c : similar(c)
|
||||
cudnnRNNForward(rnn, seqLength,
|
||||
xdesc, x,
|
||||
hDesc(h)...,
|
||||
hDesc(c)...,
|
||||
FilterDesc(T, (1, 1, length(rnn.params))), rnn.params,
|
||||
ydesc, y,
|
||||
hDesc(ho)...,
|
||||
hDesc(co)...,
|
||||
workspace, reserve)
|
||||
result = c == nothing ? (y, ho) : (y, ho, co)
|
||||
return train == Val{true} ? (reserve, result) : result
|
||||
end
|
||||
|
||||
forwardTrain(rnn::RNNDesc{T}, x::CuArray{T}, h::CuArray{T}, c = nothing) where T =
|
||||
forward(rnn, x, h, c, Val{true})
|
||||
|
||||
function cudnnRNNBackwardData(rnn::RNNDesc{T}, seqlen, yd, y, dyd, dy, dhod, dho, dcod, dco,
|
||||
wd, w, hd, h, cd, c, dxd, dx, dhd, dh, dcd, dc, ws, rs) where T
|
||||
@check ccall((:cudnnRNNBackwardData,libcudnn),cudnnStatus_t,
|
||||
(Ptr{Void}, Ptr{Void}, Cint,
|
||||
Ptr{Ptr{Void}}, Ptr{T}, Ptr{Ptr{Void}}, Ptr{T}, Ptr{Void}, Ptr{T},
|
||||
Ptr{Void}, Ptr{T}, Ptr{Void}, Ptr{T}, Ptr{Void}, Ptr{T}, Ptr{Void},
|
||||
Ptr{T}, Ptr{Ptr{Void}}, Ptr{T}, Ptr{Void}, Ptr{T}, Ptr{Void}, Ptr{T},
|
||||
Ptr{Void}, Csize_t, Ptr{Void}, Csize_t),
|
||||
libcudnn_handle[], rnn, seqlen, yd, y, dyd, dy, dhod, dho, dcod, dco,
|
||||
wd, w, hd, h, cd, c, dxd, dx, dhd, dh, dcd, dc, ws, length(ws), rs, length(rs))
|
||||
end
|
||||
|
||||
function backwardData(rnn::RNNDesc{T}, y, dy_, dho, dco, h, c, reserve) where T
|
||||
# Same as above, any more efficient way?
|
||||
dy = dy_ isa Integer ? zeros(y) : dy_
|
||||
yd = xDesc(y)
|
||||
dx = y isa AbstractVector ? similar(dy, rnn.input) : similar(dy, rnn.input, size(dy, 2))
|
||||
dh = similar(h)
|
||||
dc = c == nothing ? nothing : similar(c)
|
||||
cudnnRNNBackwardData(rnn, 1,
|
||||
yd, y, yd, dy, hDesc(dho)..., hDesc(dco)...,
|
||||
FilterDesc(T, (1, 1, length(rnn.params))), rnn.params,
|
||||
hDesc(h)..., hDesc(c)..., xDesc(dx), dx, hDesc(dh)..., hDesc(dc)...,
|
||||
workspace[], reserve)
|
||||
return c == nothing ? (dx, dh) : (dx, dh, dc)
|
||||
end
|
||||
|
||||
backwardData(rnn, y, dy, dho, hx, reserve) =
|
||||
backwardData(rnn, y, dy, dho, nothing, hx, nothing, reserve)
|
||||
|
||||
function cudnnRNNBackwardWeights(rnn::RNNDesc{T}, seqlen, xd, x, hd, h, yd, y, dwd, dw,
|
||||
workspace, reserve) where T
|
||||
@check ccall((:cudnnRNNBackwardWeights,libcudnn), cudnnStatus_t,
|
||||
(Ptr{Void}, Ptr{Void}, Cint, # handle, rnnDesc, seqLength
|
||||
Ptr{Ptr{Void}}, Ptr{T}, #x
|
||||
Ptr{Void}, Ptr{T}, #hx
|
||||
Ptr{Ptr{Void}}, Ptr{T}, #y
|
||||
Ptr{Void}, Csize_t, #ws
|
||||
Ptr{Void}, Ptr{T}, #dw
|
||||
Ptr{Void}, Csize_t), #rs
|
||||
libcudnn_handle[], rnn, seqlen, xd, x, hd, h, yd, y,
|
||||
workspace, length(workspace), dwd, dw, reserve, length(reserve))
|
||||
end
|
||||
|
||||
function backwardWeights(rnn::RNNDesc{T}, x, h, y, reserve) where T
|
||||
dw = zeros(rnn.params)
|
||||
cudnnRNNBackwardWeights(rnn, 1,
|
||||
xDesc(x), x, hDesc(h)..., xDesc(y), y,
|
||||
FilterDesc(T, (1, 1, length(dw))), dw,
|
||||
workspace[], reserve)
|
||||
return params(dw, rnn.input, rnn.hidden, ngates(rnn))
|
||||
end
|
||||
|
||||
# Interface
|
||||
|
||||
import ..Flux: Flux, relu
|
||||
import ..Tracker: TrackedArray
|
||||
using CUDAnative
|
||||
using CuArrays: @cuindex, cudims
|
||||
|
||||
function copy_transpose!(dst::CuArray, src::CuArray)
|
||||
function kernel(dst, src)
|
||||
I = @cuindex dst
|
||||
dst[I...] = src[reverse(I)...]
|
||||
return
|
||||
end
|
||||
blk, thr = cudims(dst)
|
||||
@cuda (blk, thr) kernel(dst, src)
|
||||
return dst
|
||||
end
|
||||
|
||||
CuParam{T,N} = Union{CuArray{T,N},TrackedArray{T,N,CuArray{T,N}}}
|
||||
CuRNN{T} = Flux.RNNCell{<:Union{typeof(tanh),typeof(relu)},<:CuParam{T,2},<:CuParam{T,1}}
|
||||
CuGRU{T} = Flux.GRUCell{<:CuParam{T,2},<:CuParam{T,1}}
|
||||
CuLSTM{T} = Flux.LSTMCell{<:CuParam{T,2},<:CuParam{T,1}}
|
||||
CuRNNs{T} = Union{CuRNN{T},CuGRU{T},CuLSTM{T}}
|
||||
|
||||
function copyparams!(m::CuRNNs, d::RNNDesc)
|
||||
Wi, Wh = d.weights
|
||||
copy_transpose!(Wi, Flux.data(m.Wi))
|
||||
copy_transpose!(Wh, Flux.data(m.Wh))
|
||||
copy_transpose!(d.bias, Flux.data(m.b))
|
||||
return
|
||||
end
|
||||
|
||||
function RNNDesc(m::CuRNNs{T}) where T
|
||||
h, i = length(m.h), size(m.Wi, 2)
|
||||
mode = m isa CuRNN ?
|
||||
(m.σ == tanh ? RNN_TANH : RNN_RELU) :
|
||||
m isa CuGRU ? GRU : LSTM
|
||||
r = RNNDesc{T}(mode, i, h)
|
||||
return r
|
||||
end
|
||||
|
||||
const descs = WeakKeyDict()
|
||||
|
||||
function desc(rnn)
|
||||
d = haskey(descs, rnn) ? descs[rnn] : (descs[rnn] = RNNDesc(rnn))
|
||||
copyparams!(rnn, d)
|
||||
return d
|
||||
end
|
||||
|
||||
import Flux.Tracker: data, isleaf, istracked, track, back_, @back, unbroadcast
|
||||
|
||||
mutable struct RNNCall{R}
|
||||
rnn::R
|
||||
reserve::CuVector{UInt8}
|
||||
RNNCall{R}(rnn::R) where R = new(rnn)
|
||||
end
|
||||
|
||||
RNNCall(rnn) = RNNCall{typeof(rnn)}(rnn)
|
||||
|
||||
function (c::RNNCall)(args...)
|
||||
rs, result = forwardTrain(desc(c.rnn), args...)
|
||||
c.reserve = rs
|
||||
return result
|
||||
end
|
||||
|
||||
istrain(m::CuRNNs, args...) = any(x -> x isa TrackedArray, (m.Wi, m.Wh, m.b, args...))
|
||||
|
||||
function (m::CuRNN{T})(h::CuParam{T}, x::CuParam{T}) where T <: Union{Float32,Float64}
|
||||
result = istrain(m, h, x) ?
|
||||
track(RNNCall(m), x, h) :
|
||||
forward(desc(m), x, h)
|
||||
return result[2], result[1]
|
||||
end
|
||||
|
||||
function (m::CuGRU{T})(h::CuParam{T}, x::CuParam{T}) where T <: Union{Float32,Float64}
|
||||
result = istrain(m, h, x) ?
|
||||
track(RNNCall(m), x, h) :
|
||||
forward(desc(m), x, h)
|
||||
return result[2], result[1]
|
||||
end
|
||||
|
||||
function (m::CuLSTM{T})(h::NTuple{2,CuParam{T}}, x::CuParam{T}) where T <: Union{Float32,Float64}
|
||||
result = istrain(m, h, x) ?
|
||||
track(RNNCall(m), x, h[1], h[2]) :
|
||||
forward(desc(m), x, h[1], h[2])
|
||||
return (result[2], result[3]), result[1]
|
||||
end
|
||||
|
||||
(m::CuRNN{T})(h::CuParam{T}, x) where T <: Union{Float32,Float64} = m(h, CuArray{T}(x))
|
||||
(m::CuGRU{T})(h::CuParam{T}, x) where T <: Union{Float32,Float64} = m(h, CuArray{T}(x))
|
||||
(m::CuLSTM{T})(h::NTuple{2,CuParam{T}}, x) where T <: Union{Float32,Float64} = m(h, CuArray{T}(x))
|
||||
|
||||
function accum_transpose!(dst::CuArray, src::CuArray)
|
||||
function kernel(dst, src)
|
||||
I = @cuindex dst
|
||||
dst[I...] += src[reverse(I)...]
|
||||
return
|
||||
end
|
||||
blk, thr = cudims(dst)
|
||||
@cuda (blk, thr) kernel(dst, src)
|
||||
return dst
|
||||
end
|
||||
|
||||
function back_(m::RNNCall{<:Union{CuRNN,CuGRU}}, y_, Δ, x, h)
|
||||
y, ho = y_
|
||||
dy, dho = Δ
|
||||
h_ = hBatch(x, data(h))
|
||||
dx, dh = backwardData(descs[m.rnn], y, dy, dho, h_, m.reserve)
|
||||
@back(x, dx)
|
||||
@back(h, unbroadcast(h, dh))
|
||||
(dWi, dWh), db = backwardWeights(descs[m.rnn], data(x), h_, y, m.reserve)
|
||||
# We don't have to make this assumption, it's just slightly more complex.
|
||||
@assert all(isleaf.((m.rnn.Wi, m.rnn.Wh, m.rnn.b)))
|
||||
istracked(m.rnn.Wi) && accum_transpose!(m.rnn.Wi.grad, dWi)
|
||||
istracked(m.rnn.Wh) && accum_transpose!(m.rnn.Wh.grad, dWh)
|
||||
istracked(m.rnn.b) && accum_transpose!(m.rnn.b.grad, db)
|
||||
end
|
||||
|
||||
function back_(m::RNNCall{<:CuLSTM}, y_, Δ, x, h, c)
|
||||
y, ho, co = y_
|
||||
dy, dho, dco = Δ
|
||||
h_ = hBatch(x, data(h))
|
||||
c_ = hBatch(x, data(c))
|
||||
dx, dh, dc = backwardData(descs[m.rnn], y, dy, dho, dco, h_, c_, m.reserve)
|
||||
@back(x, dx)
|
||||
@back(h, unbroadcast(h, dh))
|
||||
@back(c, unbroadcast(h, dc))
|
||||
(dWi, dWh), db = backwardWeights(descs[m.rnn], data(x), h_, y, m.reserve)
|
||||
@assert all(isleaf.((m.rnn.Wi, m.rnn.Wh, m.rnn.b)))
|
||||
istracked(m.rnn.Wi) && accum_transpose!(m.rnn.Wi.grad, dWi)
|
||||
istracked(m.rnn.Wh) && accum_transpose!(m.rnn.Wh.grad, dWh)
|
||||
istracked(m.rnn.b) && accum_transpose!(m.rnn.b.grad, db)
|
||||
end
|
||||
methods(TensorDesc)
|
||||
|
|
|
@ -0,0 +1,348 @@
|
|||
const RNN_RELU = 0 # Stock RNN with ReLu activation
|
||||
const RNN_TANH = 1 # Stock RNN with tanh activation
|
||||
const LSTM = 2 # LSTM with no peephole connections
|
||||
const GRU = 3 # Using h' = tanh(r * Uh(t-1) + Wx) and h = (1 - z) * h' + z * h(t-1)
|
||||
|
||||
const LINEAR_INPUT = 0
|
||||
const SKIP_INPUT = 1
|
||||
|
||||
const UNIDIRECTIONAL = 0
|
||||
const BIDIRECTIONAL = 1
|
||||
|
||||
const RNN_ALGO_STANDARD = 0
|
||||
const RNN_ALGO_PERSIST_STATIC = 1
|
||||
const RNN_ALGO_PERSIST_DYNAMIC = 2
|
||||
|
||||
# param layout:
|
||||
# RNN: [weight, bias] × [input, hidden]
|
||||
# GRU: [weight, bias] × [input, hidden] × [reset, update, newmem]
|
||||
# LSTM: [weight, bias] × [input, hidden] × [input, forget, newmem, output]
|
||||
|
||||
function params(w::CuVector, input, hidden, n = 1)
|
||||
slice(offset, shape) = reshape(w[offset+(1:prod(shape))], shape)
|
||||
wx = slice(0, (input, hidden*n))
|
||||
wh = slice(length(wx), (hidden, hidden*n))
|
||||
bias = w[length(wx)+length(wh) + (1:hidden*n)]
|
||||
(wx, wh), bias
|
||||
end
|
||||
|
||||
mutable struct RNNDesc{T}
|
||||
mode::Int
|
||||
input::Int
|
||||
hidden::Int
|
||||
params::CuVector{T}
|
||||
weights::NTuple{2,CuMatrix{T}}
|
||||
bias::CuVector{T}
|
||||
ptr::Ptr{Void}
|
||||
end
|
||||
|
||||
Base.unsafe_convert(::Type{Ptr{Void}}, d::RNNDesc) = d.ptr
|
||||
|
||||
function rnnParamSize(T, r, input)
|
||||
size = Csize_t[0]
|
||||
@check ccall((:cudnnGetRNNParamsSize, libcudnn), cudnnStatus_t, (Ptr{Void},Ptr{Void},Ptr{Void},Ptr{Csize_t},Cint),
|
||||
libcudnn_handle[], r, TensorDesc(T, (1,input,1)), size, cudnnDataType(T))
|
||||
return Int(size[])÷sizeof(T)
|
||||
end
|
||||
|
||||
ngates(mode) = [1, 1, 4, 3][mode+1]
|
||||
ngates(r::RNNDesc) = ngates(r.mode)
|
||||
|
||||
function RNNDesc{T}(mode::Int, input::Int, hidden::Int; layers = 1) where T
|
||||
d = [C_NULL]
|
||||
@check ccall((:cudnnCreateRNNDescriptor,libcudnn),cudnnStatus_t,(Ptr{Ptr{Void}},),d)
|
||||
|
||||
dropoutDesc = DropoutDesc(0)
|
||||
inputMode = LINEAR_INPUT
|
||||
direction = UNIDIRECTIONAL
|
||||
algo = RNN_ALGO_STANDARD
|
||||
@check ccall((:cudnnSetRNNDescriptor_v6,libcudnn), cudnnStatus_t, (Ptr{Void},Ptr{Void},Cint,Cint,Ptr{Void},Cint,Cint,Cint,Cint,Cint),
|
||||
libcudnn_handle[],d[],hidden,layers,dropoutDesc,inputMode,direction,mode,algo,cudnnDataType(T))
|
||||
|
||||
w = cuzeros(T, rnnParamSize(T, d[], input))
|
||||
# TODO: avoid reserve allocation here
|
||||
rd = RNNDesc{T}(mode, input, hidden, w, params(w, input, hidden, ngates(mode))..., d[])
|
||||
finalizer(rd, x ->
|
||||
@check ccall((:cudnnDestroyRNNDescriptor,libcudnn),cudnnStatus_t,(Ptr{Void},),x))
|
||||
return rd
|
||||
end
|
||||
|
||||
function rnnWorkspaceSize(r::RNNDesc, seqlen, xdesc)
|
||||
size = Csize_t[0]
|
||||
@check ccall((:cudnnGetRNNWorkspaceSize, libcudnn), cudnnStatus_t, (Ptr{Void},Ptr{Void},Cint,Ptr{Ptr{Void}},Ptr{Csize_t}),
|
||||
libcudnn_handle[], r, seqlen, xdesc, size)
|
||||
return Int(size[])
|
||||
end
|
||||
|
||||
const workspace = [CuVector{UInt8}(1)]
|
||||
|
||||
getworkspace(bytes) =
|
||||
length(workspace[]) ≥ bytes ?
|
||||
workspace[] :
|
||||
(workspace[] = CuVector{UInt8}(bytes))
|
||||
|
||||
getworkspace(r::RNNDesc, seqlen, xdesc) =
|
||||
getworkspace(rnnWorkspaceSize(r, seqlen, xdesc))
|
||||
|
||||
function rnnTrainingReserveSize(r::RNNDesc, seqlen, xdesc)
|
||||
size = Csize_t[0]
|
||||
@check ccall((:cudnnGetRNNTrainingReserveSize,libcudnn), cudnnStatus_t, (Ptr{Void}, Ptr{Void}, Cint, Ptr{Ptr{Void}}, Ptr{Csize_t}),
|
||||
libcudnn_handle[], r, seqlen, xdesc, size)
|
||||
return Int(size[])
|
||||
end
|
||||
|
||||
function cudnnRNNForward(rnn::RNNDesc{T}, seqlen, xd, x, hd, h, cd, c, wd, w, yd, y, hod, ho, cod, co,
|
||||
workspace, reserve=nothing) where T
|
||||
if reserve == nothing
|
||||
@check ccall((:cudnnRNNForwardInference, libcudnn), cudnnStatus_t,
|
||||
(Ptr{Void}, Ptr{Void}, Cint,
|
||||
Ptr{Ptr{Void}}, Ptr{T}, Ptr{Void}, Ptr{T}, Ptr{Void}, Ptr{T},
|
||||
Ptr{Void}, Ptr{T}, Ptr{Ptr{Void}}, Ptr{T}, Ptr{Void}, Ptr{T},
|
||||
Ptr{Void}, Ptr{T},
|
||||
Ptr{Void}, Csize_t),
|
||||
libcudnn_handle[], rnn, seqlen,
|
||||
xd, x, hd, h, cd, c, wd, w, yd, y, hod, ho, cod, co,
|
||||
workspace, length(workspace))
|
||||
else
|
||||
@check ccall((:cudnnRNNForwardTraining, libcudnn), cudnnStatus_t,
|
||||
(Ptr{Void}, Ptr{Void}, Cint,
|
||||
Ptr{Ptr{Void}}, Ptr{T}, Ptr{Void}, Ptr{T}, Ptr{Void}, Ptr{T}, Ptr{Void}, Ptr{T}, Ptr{Ptr{Void}}, Ptr{T}, Ptr{Void}, Ptr{T}, Ptr{Void}, Ptr{T},
|
||||
Ptr{Void}, Csize_t, Ptr{Void}, Csize_t),
|
||||
libcudnn_handle[], rnn, seqlen,
|
||||
xd, x, hd, h, cd, c, wd, w, yd, y, hod, ho, cod, co,
|
||||
workspace, length(workspace), reserve, length(reserve))
|
||||
end
|
||||
end
|
||||
|
||||
xDesc(x) = [TensorDesc(eltype(x), (1, size(x, 1), size(x, 2)))]
|
||||
|
||||
hDesc(h::Void) = C_NULL, C_NULL
|
||||
hDesc(x::Integer) = (@assert x == 0; hDesc(nothing))
|
||||
function hDesc(h::CuArray)
|
||||
TensorDesc(eltype(h), (size(h, 1), size(h, 2), 1)), h
|
||||
end
|
||||
|
||||
# TODO: can we just manipulate strides here?
|
||||
# TODO: should use repmat, but this isn't implemented.
|
||||
hBatch(x::AbstractVector, h::CuVector) = h
|
||||
hBatch(x::AbstractMatrix, h::CuVector) = h .* cuones(1, size(x, 2))
|
||||
hBatch(x::AbstractMatrix, h::CuMatrix) = h .* cuones(1, size(h,2) == 1 ? size(x,2) : 1)
|
||||
|
||||
function forward(rnn::RNNDesc{T}, x::CuArray{T}, h_::CuArray{T}, c_ = nothing, train = Val{false}) where T
|
||||
h = hBatch(x, h_)
|
||||
c = c_ == nothing ? nothing : hBatch(x, c_)
|
||||
@assert size(x, 1) == rnn.input
|
||||
@assert size(h, 1) == rnn.hidden
|
||||
@assert size(x, 2) == size(h, 2)
|
||||
seqLength = 1
|
||||
xdesc = xDesc(x)
|
||||
y = x isa AbstractVector ? similar(x, rnn.hidden) : similar(x, rnn.hidden, size(x, 2))
|
||||
ho = similar(h)
|
||||
ydesc = xDesc(y)
|
||||
workspace = getworkspace(rnn, seqLength, xdesc)
|
||||
reserve = train == Val{true} ?
|
||||
CuVector{UInt8}(rnnTrainingReserveSize(rnn, seqLength, xdesc)) :
|
||||
nothing
|
||||
co = c == nothing ? c : similar(c)
|
||||
cudnnRNNForward(rnn, seqLength,
|
||||
xdesc, x,
|
||||
hDesc(h)...,
|
||||
hDesc(c)...,
|
||||
FilterDesc(T, (1, 1, length(rnn.params))), rnn.params,
|
||||
ydesc, y,
|
||||
hDesc(ho)...,
|
||||
hDesc(co)...,
|
||||
workspace, reserve)
|
||||
result = c == nothing ? (y, ho) : (y, ho, co)
|
||||
return train == Val{true} ? (reserve, result) : result
|
||||
end
|
||||
|
||||
forwardTrain(rnn::RNNDesc{T}, x::CuArray{T}, h::CuArray{T}, c = nothing) where T =
|
||||
forward(rnn, x, h, c, Val{true})
|
||||
|
||||
function cudnnRNNBackwardData(rnn::RNNDesc{T}, seqlen, yd, y, dyd, dy, dhod, dho, dcod, dco,
|
||||
wd, w, hd, h, cd, c, dxd, dx, dhd, dh, dcd, dc, ws, rs) where T
|
||||
@check ccall((:cudnnRNNBackwardData,libcudnn),cudnnStatus_t,
|
||||
(Ptr{Void}, Ptr{Void}, Cint,
|
||||
Ptr{Ptr{Void}}, Ptr{T}, Ptr{Ptr{Void}}, Ptr{T}, Ptr{Void}, Ptr{T},
|
||||
Ptr{Void}, Ptr{T}, Ptr{Void}, Ptr{T}, Ptr{Void}, Ptr{T}, Ptr{Void},
|
||||
Ptr{T}, Ptr{Ptr{Void}}, Ptr{T}, Ptr{Void}, Ptr{T}, Ptr{Void}, Ptr{T},
|
||||
Ptr{Void}, Csize_t, Ptr{Void}, Csize_t),
|
||||
libcudnn_handle[], rnn, seqlen, yd, y, dyd, dy, dhod, dho, dcod, dco,
|
||||
wd, w, hd, h, cd, c, dxd, dx, dhd, dh, dcd, dc, ws, length(ws), rs, length(rs))
|
||||
end
|
||||
|
||||
function backwardData(rnn::RNNDesc{T}, y, dy_, dho, dco, h, c, reserve) where T
|
||||
# Same as above, any more efficient way?
|
||||
dy = dy_ isa Integer ? zeros(y) : dy_
|
||||
yd = xDesc(y)
|
||||
dx = y isa AbstractVector ? similar(dy, rnn.input) : similar(dy, rnn.input, size(dy, 2))
|
||||
dh = similar(h)
|
||||
dc = c == nothing ? nothing : similar(c)
|
||||
cudnnRNNBackwardData(rnn, 1,
|
||||
yd, y, yd, dy, hDesc(dho)..., hDesc(dco)...,
|
||||
FilterDesc(T, (1, 1, length(rnn.params))), rnn.params,
|
||||
hDesc(h)..., hDesc(c)..., xDesc(dx), dx, hDesc(dh)..., hDesc(dc)...,
|
||||
workspace[], reserve)
|
||||
return c == nothing ? (dx, dh) : (dx, dh, dc)
|
||||
end
|
||||
|
||||
backwardData(rnn, y, dy, dho, hx, reserve) =
|
||||
backwardData(rnn, y, dy, dho, nothing, hx, nothing, reserve)
|
||||
|
||||
function cudnnRNNBackwardWeights(rnn::RNNDesc{T}, seqlen, xd, x, hd, h, yd, y, dwd, dw,
|
||||
workspace, reserve) where T
|
||||
@check ccall((:cudnnRNNBackwardWeights,libcudnn), cudnnStatus_t,
|
||||
(Ptr{Void}, Ptr{Void}, Cint, # handle, rnnDesc, seqLength
|
||||
Ptr{Ptr{Void}}, Ptr{T}, #x
|
||||
Ptr{Void}, Ptr{T}, #hx
|
||||
Ptr{Ptr{Void}}, Ptr{T}, #y
|
||||
Ptr{Void}, Csize_t, #ws
|
||||
Ptr{Void}, Ptr{T}, #dw
|
||||
Ptr{Void}, Csize_t), #rs
|
||||
libcudnn_handle[], rnn, seqlen, xd, x, hd, h, yd, y,
|
||||
workspace, length(workspace), dwd, dw, reserve, length(reserve))
|
||||
end
|
||||
|
||||
function backwardWeights(rnn::RNNDesc{T}, x, h, y, reserve) where T
|
||||
dw = zeros(rnn.params)
|
||||
cudnnRNNBackwardWeights(rnn, 1,
|
||||
xDesc(x), x, hDesc(h)..., xDesc(y), y,
|
||||
FilterDesc(T, (1, 1, length(dw))), dw,
|
||||
workspace[], reserve)
|
||||
return params(dw, rnn.input, rnn.hidden, ngates(rnn))
|
||||
end
|
||||
|
||||
# Interface
|
||||
|
||||
import ..Flux: Flux, relu
|
||||
import ..Tracker: TrackedArray
|
||||
using CUDAnative
|
||||
using CuArrays: @cuindex, cudims
|
||||
|
||||
function copy_transpose!(dst::CuArray, src::CuArray)
|
||||
function kernel(dst, src)
|
||||
I = @cuindex dst
|
||||
dst[I...] = src[reverse(I)...]
|
||||
return
|
||||
end
|
||||
blk, thr = cudims(dst)
|
||||
@cuda (blk, thr) kernel(dst, src)
|
||||
return dst
|
||||
end
|
||||
|
||||
CuParam{T,N} = Union{CuArray{T,N},TrackedArray{T,N,CuArray{T,N}}}
|
||||
CuRNN{T} = Flux.RNNCell{<:Union{typeof(tanh),typeof(relu)},<:CuParam{T,2},<:CuParam{T,1}}
|
||||
CuGRU{T} = Flux.GRUCell{<:CuParam{T,2},<:CuParam{T,1}}
|
||||
CuLSTM{T} = Flux.LSTMCell{<:CuParam{T,2},<:CuParam{T,1}}
|
||||
CuRNNs{T} = Union{CuRNN{T},CuGRU{T},CuLSTM{T}}
|
||||
|
||||
function copyparams!(m::CuRNNs, d::RNNDesc)
|
||||
Wi, Wh = d.weights
|
||||
copy_transpose!(Wi, Flux.data(m.Wi))
|
||||
copy_transpose!(Wh, Flux.data(m.Wh))
|
||||
copy_transpose!(d.bias, Flux.data(m.b))
|
||||
return
|
||||
end
|
||||
|
||||
function RNNDesc(m::CuRNNs{T}) where T
|
||||
h, i = length(m.h), size(m.Wi, 2)
|
||||
mode = m isa CuRNN ?
|
||||
(m.σ == tanh ? RNN_TANH : RNN_RELU) :
|
||||
m isa CuGRU ? GRU : LSTM
|
||||
r = RNNDesc{T}(mode, i, h)
|
||||
return r
|
||||
end
|
||||
|
||||
const descs = WeakKeyDict()
|
||||
|
||||
function desc(rnn)
|
||||
d = haskey(descs, rnn) ? descs[rnn] : (descs[rnn] = RNNDesc(rnn))
|
||||
copyparams!(rnn, d)
|
||||
return d
|
||||
end
|
||||
|
||||
import Flux.Tracker: data, isleaf, istracked, track, back_, @back, unbroadcast
|
||||
|
||||
mutable struct RNNCall{R}
|
||||
rnn::R
|
||||
reserve::CuVector{UInt8}
|
||||
RNNCall{R}(rnn::R) where R = new(rnn)
|
||||
end
|
||||
|
||||
RNNCall(rnn) = RNNCall{typeof(rnn)}(rnn)
|
||||
|
||||
function (c::RNNCall)(args...)
|
||||
rs, result = forwardTrain(desc(c.rnn), args...)
|
||||
c.reserve = rs
|
||||
return result
|
||||
end
|
||||
|
||||
istrain(m::CuRNNs, args...) = any(x -> x isa TrackedArray, (m.Wi, m.Wh, m.b, args...))
|
||||
|
||||
function (m::CuRNN{T})(h::CuParam{T}, x::CuParam{T}) where T <: Union{Float32,Float64}
|
||||
result = istrain(m, h, x) ?
|
||||
track(RNNCall(m), x, h) :
|
||||
forward(desc(m), x, h)
|
||||
return result[2], result[1]
|
||||
end
|
||||
|
||||
function (m::CuGRU{T})(h::CuParam{T}, x::CuParam{T}) where T <: Union{Float32,Float64}
|
||||
result = istrain(m, h, x) ?
|
||||
track(RNNCall(m), x, h) :
|
||||
forward(desc(m), x, h)
|
||||
return result[2], result[1]
|
||||
end
|
||||
|
||||
function (m::CuLSTM{T})(h::NTuple{2,CuParam{T}}, x::CuParam{T}) where T <: Union{Float32,Float64}
|
||||
result = istrain(m, h, x) ?
|
||||
track(RNNCall(m), x, h[1], h[2]) :
|
||||
forward(desc(m), x, h[1], h[2])
|
||||
return (result[2], result[3]), result[1]
|
||||
end
|
||||
|
||||
(m::CuRNN{T})(h::CuParam{T}, x) where T <: Union{Float32,Float64} = m(h, CuArray{T}(x))
|
||||
(m::CuGRU{T})(h::CuParam{T}, x) where T <: Union{Float32,Float64} = m(h, CuArray{T}(x))
|
||||
(m::CuLSTM{T})(h::NTuple{2,CuParam{T}}, x) where T <: Union{Float32,Float64} = m(h, CuArray{T}(x))
|
||||
|
||||
function accum_transpose!(dst::CuArray, src::CuArray)
|
||||
function kernel(dst, src)
|
||||
I = @cuindex dst
|
||||
dst[I...] += src[reverse(I)...]
|
||||
return
|
||||
end
|
||||
blk, thr = cudims(dst)
|
||||
@cuda (blk, thr) kernel(dst, src)
|
||||
return dst
|
||||
end
|
||||
|
||||
function back_(m::RNNCall{<:Union{CuRNN,CuGRU}}, y_, Δ, x, h)
|
||||
y, ho = y_
|
||||
dy, dho = Δ
|
||||
h_ = hBatch(x, data(h))
|
||||
dx, dh = backwardData(descs[m.rnn], y, dy, dho, h_, m.reserve)
|
||||
@back(x, dx)
|
||||
@back(h, unbroadcast(h, dh))
|
||||
(dWi, dWh), db = backwardWeights(descs[m.rnn], data(x), h_, y, m.reserve)
|
||||
# We don't have to make this assumption, it's just slightly more complex.
|
||||
@assert all(isleaf.((m.rnn.Wi, m.rnn.Wh, m.rnn.b)))
|
||||
istracked(m.rnn.Wi) && accum_transpose!(m.rnn.Wi.grad, dWi)
|
||||
istracked(m.rnn.Wh) && accum_transpose!(m.rnn.Wh.grad, dWh)
|
||||
istracked(m.rnn.b) && accum_transpose!(m.rnn.b.grad, db)
|
||||
end
|
||||
|
||||
function back_(m::RNNCall{<:CuLSTM}, y_, Δ, x, h, c)
|
||||
y, ho, co = y_
|
||||
dy, dho, dco = Δ
|
||||
h_ = hBatch(x, data(h))
|
||||
c_ = hBatch(x, data(c))
|
||||
dx, dh, dc = backwardData(descs[m.rnn], y, dy, dho, dco, h_, c_, m.reserve)
|
||||
@back(x, dx)
|
||||
@back(h, unbroadcast(h, dh))
|
||||
@back(c, unbroadcast(h, dc))
|
||||
(dWi, dWh), db = backwardWeights(descs[m.rnn], data(x), h_, y, m.reserve)
|
||||
@assert all(isleaf.((m.rnn.Wi, m.rnn.Wh, m.rnn.b)))
|
||||
istracked(m.rnn.Wi) && accum_transpose!(m.rnn.Wi.grad, dWi)
|
||||
istracked(m.rnn.Wh) && accum_transpose!(m.rnn.Wh.grad, dWh)
|
||||
istracked(m.rnn.b) && accum_transpose!(m.rnn.b.grad, db)
|
||||
end
|
Loading…
Reference in New Issue