# Flux.jl/src/cuda/cudnn.jl

using .CuArrays.CUDNN: @check, libcudnn, cudnnStatus_t, libcudnn_handle,
  cudnnDataType, TensorDesc, FilterDesc

using LinearAlgebra

mutable struct DropoutDesc
  ptr::Ptr{Nothing}
  states::CuVector{UInt8}
end

Base.unsafe_convert(::Type{Ptr{Nothing}}, dd::DropoutDesc) = dd.ptr

function DropoutDesc(ρ::Real; seed::Integer=0)
  d = [C_NULL]
  s = Csize_t[0]
  @check ccall((:cudnnCreateDropoutDescriptor,libcudnn), cudnnStatus_t, (Ptr{Ptr{Nothing}},), d)
  @check ccall((:cudnnDropoutGetStatesSize,libcudnn),cudnnStatus_t,(Ptr{Nothing},Ptr{Csize_t}),libcudnn_handle[],s)
  states = CuArray{UInt8}(s[]) # TODO: can we drop this when ρ=0?
  desc = DropoutDesc(d[], states)
  @check ccall((:cudnnSetDropoutDescriptor,libcudnn),cudnnStatus_t,(Ptr{Nothing},Ptr{Nothing},Cfloat,Ptr{Nothing},Csize_t,Culonglong),
    desc,libcudnn_handle[],ρ,states,length(states),seed)
  finalizer(desc) do x
    @check ccall((:cudnnDestroyDropoutDescriptor,libcudnn),cudnnStatus_t,(Ptr{Nothing},),x)
  end
  return desc
end
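
# Only the no-dropout case is needed below: `RNNDesc` always passes
# `DropoutDesc(0)`. A nonzero rate, e.g. `DropoutDesc(0.5; seed = 1)`, would
# attach cuDNN's RNG state buffer in the same way.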

const RNN_RELU = 0 # Stock RNN with ReLU activation
const RNN_TANH = 1 # Stock RNN with tanh activation
const LSTM = 2     # LSTM with no peephole connections
const GRU = 3      # Using h' = tanh(r * Uh(t-1) + Wx) and h = (1 - z) * h' + z * h(t-1)

const LINEAR_INPUT = 0
const SKIP_INPUT = 1

const UNIDIRECTIONAL = 0
const BIDIRECTIONAL = 1

const RNN_ALGO_STANDARD = 0
const RNN_ALGO_PERSIST_STATIC = 1
const RNN_ALGO_PERSIST_DYNAMIC = 2

# param layout:
#   RNN:  [weight, bias] × [input, hidden]
#   GRU:  [weight, bias] × [input, hidden] × [reset, update, newmem]
#   LSTM: [weight, bias] × [input, hidden] × [input, forget, newmem, output]

function params(w::CuVector, input, hidden, n = 1)
  slice(offset, shape) = reshape(w[offset.+(1:prod(shape))], shape)
  wx = slice(0, (input, hidden*n))
  wh = slice(length(wx), (hidden, hidden*n))
  bias = w[length(wx)+length(wh) .+ (1:hidden*n)]
  (wx, wh), bias
end
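
# For example, an LSTM (4 gates) with input = 2 and hidden = 3 stores 2×12
# input weights, then 3×12 hidden weights, then 12 biases in the flat cuDNN
# buffer, so for any 72-element CuVector `w`:
#
#   (wx, wh), b = params(w, 2, 3, 4)
#   size(wx) == (2, 12) && size(wh) == (3, 12) && length(b) == 12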

mutable struct RNNDesc{T}
  mode::Int
  input::Int
  hidden::Int
  params::CuVector{T}
  weights::NTuple{2,CuMatrix{T}}
  bias::CuVector{T}
  ptr::Ptr{Nothing}
end

Base.unsafe_convert(::Type{Ptr{Nothing}}, d::RNNDesc) = d.ptr

function rnnParamSize(T, r, input)
  size = Csize_t[0]
  @check ccall((:cudnnGetRNNParamsSize, libcudnn), cudnnStatus_t, (Ptr{Nothing},Ptr{Nothing},Ptr{Nothing},Ptr{Csize_t},Cint),
    libcudnn_handle[], r, TensorDesc(T, (1,input,1)), size, cudnnDataType(T))
  return Int(size[])÷sizeof(T)
end

ngates(mode) = [1, 1, 4, 3][mode+1]
ngates(r::RNNDesc) = ngates(r.mode)

function RNNDesc{T}(mode::Int, input::Int, hidden::Int; layers = 1) where T
  d = [C_NULL]
  @check ccall((:cudnnCreateRNNDescriptor,libcudnn),cudnnStatus_t,(Ptr{Ptr{Nothing}},),d)

  dropoutDesc = DropoutDesc(0)
  inputMode = LINEAR_INPUT
  direction = UNIDIRECTIONAL
  algo = RNN_ALGO_STANDARD
  @check ccall((:cudnnSetRNNDescriptor_v6,libcudnn), cudnnStatus_t, (Ptr{Nothing},Ptr{Nothing},Cint,Cint,Ptr{Nothing},Cint,Cint,Cint,Cint,Cint),
    libcudnn_handle[],d[],hidden,layers,dropoutDesc,inputMode,direction,mode,algo,cudnnDataType(T))

  w = cuzeros(T, rnnParamSize(T, d[], input))
  # TODO: avoid reserve allocation here
  rd = RNNDesc{T}(mode, input, hidden, w, params(w, input, hidden, ngates(mode))..., d[])
  finalizer(rd) do x
    @check ccall((:cudnnDestroyRNNDescriptor,libcudnn),cudnnStatus_t,(Ptr{Nothing},),x)
  end
  return rd
end
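
# Usage sketch, assuming a working CUDA + cuDNN setup: a single-layer LSTM
# descriptor mapping 10 inputs to 20 hidden units, whose flat parameter
# vector matches the size cuDNN reports.
#
#   rd = RNNDesc{Float32}(LSTM, 10, 20)
#   length(rd.params) == rnnParamSize(Float32, rd, 10)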

function rnnWorkspaceSize(r::RNNDesc, seqlen, xdesc)
  size = Csize_t[0]
  @check ccall((:cudnnGetRNNWorkspaceSize, libcudnn), cudnnStatus_t, (Ptr{Nothing},Ptr{Nothing},Cint,Ptr{Ptr{Nothing}},Ptr{Csize_t}),
    libcudnn_handle[], r, seqlen, xdesc, size)
  return Int(size[])
end

const workspace = [CuVector{UInt8}(1)]

getworkspace(bytes) =
  length(workspace[]) ≥ bytes ?
    workspace[] :
    (workspace[] = CuVector{UInt8}(bytes))

getworkspace(r::RNNDesc, seqlen, xdesc) =
  getworkspace(rnnWorkspaceSize(r, seqlen, xdesc))
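
# The shared buffer only ever grows, so repeated calls reuse it, e.g.:
#
#   ws = getworkspace(1024)     # resizes workspace[] to 1024 bytes
#   ws === getworkspace(512)    # true: the existing buffer is large enough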

function rnnTrainingReserveSize(r::RNNDesc, seqlen, xdesc)
  size = Csize_t[0]
  @check ccall((:cudnnGetRNNTrainingReserveSize,libcudnn), cudnnStatus_t, (Ptr{Nothing}, Ptr{Nothing}, Cint, Ptr{Ptr{Nothing}}, Ptr{Csize_t}),
    libcudnn_handle[], r, seqlen, xdesc, size)
  return Int(size[])
end

function cudnnRNNForward(rnn::RNNDesc{T}, seqlen, xd, x, hd, h, cd, c, wd, w, yd, y, hod, ho, cod, co,
                         workspace, reserve=nothing) where T
  if reserve == nothing
    @check ccall((:cudnnRNNForwardInference, libcudnn), cudnnStatus_t,
                 (Ptr{Nothing}, Ptr{Nothing}, Cint,
                  Ptr{Ptr{Nothing}}, Ptr{T}, Ptr{Nothing}, Ptr{T}, Ptr{Nothing}, Ptr{T},
                  Ptr{Nothing}, Ptr{T}, Ptr{Ptr{Nothing}}, Ptr{T}, Ptr{Nothing}, Ptr{T},
                  Ptr{Nothing}, Ptr{T},
                  Ptr{Nothing}, Csize_t),
                 libcudnn_handle[], rnn, seqlen,
                 xd, x, hd, h, cd, c, wd, w, yd, y, hod, ho, cod, co,
                 workspace, length(workspace))
  else
    @check ccall((:cudnnRNNForwardTraining, libcudnn), cudnnStatus_t,
                 (Ptr{Nothing}, Ptr{Nothing}, Cint,
                  Ptr{Ptr{Nothing}}, Ptr{T}, Ptr{Nothing}, Ptr{T}, Ptr{Nothing}, Ptr{T},
                  Ptr{Nothing}, Ptr{T}, Ptr{Ptr{Nothing}}, Ptr{T}, Ptr{Nothing}, Ptr{T},
                  Ptr{Nothing}, Ptr{T},
                  Ptr{Nothing}, Csize_t, Ptr{Nothing}, Csize_t),
                 libcudnn_handle[], rnn, seqlen,
                 xd, x, hd, h, cd, c, wd, w, yd, y, hod, ho, cod, co,
                 workspace, length(workspace), reserve, length(reserve))
  end
end

xDesc(x) = [TensorDesc(eltype(x), (1, size(x, 1), size(x, 2)))]

hDesc(h::Nothing) = C_NULL, C_NULL
hDesc(x::Integer) = (@assert x == 0; hDesc(nothing))
function hDesc(h::CuArray)
  TensorDesc(eltype(h), (size(h, 1), size(h, 2), 1)), h
end

# TODO: can we just manipulate strides here?
# TODO: should use repmat, but this isn't implemented.
hBatch(x::AbstractVector, h::CuVector) = h
hBatch(x::AbstractMatrix, h::CuVector) = h .* cuones(1, size(x, 2))
hBatch(x::AbstractMatrix, h::CuMatrix) = h .* cuones(1, size(h,2) == 1 ? size(x,2) : 1)
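
# e.g. broadcasting a single hidden state across a batch of 5:
#
#   x = cuzeros(Float32, 10, 5)
#   h = cuzeros(Float32, 20)
#   size(hBatch(x, h)) == (20, 5)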

function forward(rnn::RNNDesc{T}, x::CuArray{T}, h_::CuArray{T}, c_ = nothing, train = Val{false}) where T
  h = hBatch(x, h_)
  c = c_ == nothing ? nothing : hBatch(x, c_)
  @assert size(x, 1) == rnn.input
  @assert size(h, 1) == rnn.hidden
  @assert size(x, 2) == size(h, 2)
  seqLength = 1
  xdesc = xDesc(x)
  y = x isa AbstractVector ? similar(x, rnn.hidden) : similar(x, rnn.hidden, size(x, 2))
  ho = similar(h)
  ydesc = xDesc(y)
  workspace = getworkspace(rnn, seqLength, xdesc)
  reserve = train == Val{true} ?
    CuVector{UInt8}(rnnTrainingReserveSize(rnn, seqLength, xdesc)) :
    nothing
  co = c == nothing ? c : similar(c)
  cudnnRNNForward(rnn, seqLength,
                  xdesc, x,
                  hDesc(h)...,
                  hDesc(c)...,
                  FilterDesc(T, (1, 1, length(rnn.params))), rnn.params,
                  ydesc, y,
                  hDesc(ho)...,
                  hDesc(co)...,
                  workspace, reserve)
  result = c == nothing ? (y, ho) : (y, ho, co)
  return train == Val{true} ? (reserve, result) : result
end

forwardTrain(rnn::RNNDesc{T}, x::CuArray{T}, h::CuArray{T}, c = nothing) where T =
  forward(rnn, x, h, c, Val{true})
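
# Inference sketch for the LSTM descriptor above, batch of 5 (assumes a GPU):
#
#   x = cuzeros(Float32, rd.input, 5)
#   h = cuzeros(Float32, rd.hidden, 5)
#   c = cuzeros(Float32, rd.hidden, 5)
#   y, ho, co = forward(rd, x, h, c)   # output plus updated hidden/cell state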

function cudnnRNNBackwardData(rnn::RNNDesc{T}, seqlen, yd, y, dyd, dy, dhod, dho, dcod, dco,
                              wd, w, hd, h, cd, c, dxd, dx, dhd, dh, dcd, dc, ws, rs) where T
  @check ccall((:cudnnRNNBackwardData,libcudnn),cudnnStatus_t,
               (Ptr{Nothing}, Ptr{Nothing}, Cint,
                Ptr{Ptr{Nothing}}, Ptr{T}, Ptr{Ptr{Nothing}}, Ptr{T}, Ptr{Nothing}, Ptr{T},
                Ptr{Nothing}, Ptr{T}, Ptr{Nothing}, Ptr{T}, Ptr{Nothing}, Ptr{T}, Ptr{Nothing},
                Ptr{T}, Ptr{Ptr{Nothing}}, Ptr{T}, Ptr{Nothing}, Ptr{T}, Ptr{Nothing}, Ptr{T},
                Ptr{Nothing}, Csize_t, Ptr{Nothing}, Csize_t),
               libcudnn_handle[], rnn, seqlen, yd, y, dyd, dy, dhod, dho, dcod, dco,
               wd, w, hd, h, cd, c, dxd, dx, dhd, dh, dcd, dc, ws, length(ws), rs, length(rs))
end

function backwardData(rnn::RNNDesc{T}, y, dy_, dho, dco, h, c, reserve) where T
  # Same as above, any more efficient way?
  dy = dy_ isa Integer ? zero(y) : dy_
  yd = xDesc(y)
  dx = y isa AbstractVector ? similar(dy, rnn.input) : similar(dy, rnn.input, size(dy, 2))
  dh = similar(h)
  dc = c == nothing ? nothing : similar(c)
  cudnnRNNBackwardData(rnn, 1,
    yd, y, yd, dy, hDesc(dho)..., hDesc(dco)...,
    FilterDesc(T, (1, 1, length(rnn.params))), rnn.params,
    hDesc(h)..., hDesc(c)..., xDesc(dx), dx, hDesc(dh)..., hDesc(dc)...,
    workspace[], reserve)
  return c == nothing ? (dx, dh) : (dx, dh, dc)
end

backwardData(rnn, y, dy, dho, hx, reserve) =
  backwardData(rnn, y, dy, dho, nothing, hx, nothing, reserve)

function cudnnRNNBackwardWeights(rnn::RNNDesc{T}, seqlen, xd, x, hd, h, yd, y, dwd, dw,
                                 workspace, reserve) where T
  @check ccall((:cudnnRNNBackwardWeights,libcudnn), cudnnStatus_t,
               (Ptr{Nothing}, Ptr{Nothing}, Cint, # handle, rnnDesc, seqLength
                Ptr{Ptr{Nothing}}, Ptr{T}, # x
                Ptr{Nothing}, Ptr{T},      # hx
                Ptr{Ptr{Nothing}}, Ptr{T}, # y
                Ptr{Nothing}, Csize_t,     # ws
                Ptr{Nothing}, Ptr{T},      # dw
                Ptr{Nothing}, Csize_t),    # rs
               libcudnn_handle[], rnn, seqlen, xd, x, hd, h, yd, y,
               workspace, length(workspace), dwd, dw, reserve, length(reserve))
end

function backwardWeights(rnn::RNNDesc{T}, x, h, y, reserve) where T
  dw = zero(rnn.params)
  cudnnRNNBackwardWeights(rnn, 1,
    xDesc(x), x, hDesc(h)..., xDesc(y), y,
    FilterDesc(T, (1, 1, length(dw))), dw,
    workspace[], reserve)
  return params(dw, rnn.input, rnn.hidden, ngates(rnn))
end
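
# Training-time flow (sketch): `forwardTrain` returns cuDNN's reserve buffer
# alongside the outputs, and that same reserve must be threaded through both
# backward passes. For a GRU descriptor `gd` and batched `x`, `h`, `dy`, `dho`:
#
#   reserve, (y, ho) = forwardTrain(gd, x, h)
#   dx, dh = backwardData(gd, y, dy, dho, hBatch(x, h), reserve)
#   (dWi, dWh), db = backwardWeights(gd, x, hBatch(x, h), y, reserve)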

# Interface

import ..Flux: Flux, relu
import ..Tracker: TrackedArray
using .CuArrays.CUDAnative
using .CuArrays: @cuindex, cudims

function LinearAlgebra.copy_transpose!(dst::CuArray, src::CuArray)
  function kernel(dst, src)
    I = @cuindex dst
    dst[I...] = src[reverse(I)...]
    return
  end
  blk, thr = cudims(dst)
  @cuda blocks=blk threads=thr kernel(dst, src)
  return dst
end
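
# Sketch of the kernel's effect, assuming a GPU and CuArrays' `cu`: it writes
# dst[i,j] = src[j,i] directly, without materialising a transposed copy.
#
#   src = cu(rand(Float32, 4, 3))
#   dst = cuzeros(Float32, 3, 4)
#   copy_transpose!(dst, src)   # dst now holds transpose(src)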

CuParam{T,N} = Union{CuArray{T,N},TrackedArray{T,N,CuArray{T,N}}}
CuRNN{T} = Flux.RNNCell{<:Union{typeof(tanh),typeof(relu)},<:CuParam{T,2},<:CuParam{T,1}}
CuGRU{T} = Flux.GRUCell{<:CuParam{T,2},<:CuParam{T,1}}
CuLSTM{T} = Flux.LSTMCell{<:CuParam{T,2},<:CuParam{T,1}}
CuRNNs{T} = Union{CuRNN{T},CuGRU{T},CuLSTM{T}}

function copyparams!(m::CuRNNs, d::RNNDesc)
  Wi, Wh = d.weights
  copy_transpose!(Wi, Flux.data(m.Wi))
  copy_transpose!(Wh, Flux.data(m.Wh))
  copy_transpose!(d.bias, Flux.data(m.b))
  return
end

function RNNDesc(m::CuRNNs{T}) where T
  h, i = length(m.h), size(m.Wi, 2)
  mode = m isa CuRNN ?
    (m.σ == tanh ? RNN_TANH : RNN_RELU) :
    m isa CuGRU ? GRU : LSTM
  r = RNNDesc{T}(mode, i, h)
  return r
end

const descs = WeakKeyDict()

function desc(rnn)
  d = haskey(descs, rnn) ? descs[rnn] : (descs[rnn] = RNNDesc(rnn))
  copyparams!(rnn, d)
  return d
end

import Flux.Tracker
import Flux.Tracker: data, istracked, track, unbroadcast, @grad, nobacksies

istrain(m::CuRNNs, args...) = any(x -> x isa TrackedArray, (m.Wi, m.Wh, m.b, args...))

function (m::CuRNN{T})(h::CuParam{T}, x::CuParam{T}) where T <: Union{Float32,Float64}
  result = istrain(m, h, x) ?
    track(m, x, h, m.Wi, m.Wh, m.b) :
    forward(desc(m), x, h)
  return result[2], result[1]
end

function (m::CuGRU{T})(h::CuParam{T}, x::CuParam{T}) where T <: Union{Float32,Float64}
  result = istrain(m, h, x) ?
    track(m, x, h, m.Wi, m.Wh, m.b) :
    forward(desc(m), x, h)
  return result[2], result[1]
end

function (m::CuLSTM{T})(h::NTuple{2,CuParam{T}}, x::CuParam{T}) where T <: Union{Float32,Float64}
  result = istrain(m, h, x) ?
    track(m, x, h[1], h[2], m.Wi, m.Wh, m.b) :
    forward(desc(m), x, h[1], h[2])
  return (result[2], result[3]), result[1]
end

(m::CuRNN{T})(h::CuParam{T}, x) where T <: Union{Float32,Float64} = m(h, CuArray{T}(x))
(m::CuGRU{T})(h::CuParam{T}, x) where T <: Union{Float32,Float64} = m(h, CuArray{T}(x))
(m::CuLSTM{T})(h::NTuple{2,CuParam{T}}, x) where T <: Union{Float32,Float64} = m(h, CuArray{T}(x))
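
# End-to-end sketch (assumes a GPU; `gpu` moves the cell's parameters onto
# CuArrays): once the fields are CuParams these methods take over, picking the
# tracked training path when any parameter or argument is a TrackedArray and
# plain cuDNN inference otherwise.
#
#   m = Flux.LSTMCell(10, 20) |> Flux.gpu
#   h = (cuzeros(Float32, 20), cuzeros(Float32, 20))
#   x = cuzeros(Float32, 10)
#   h′, y = m(h, x)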

@grad function (m::Union{CuRNN,CuGRU})(x, h, Wi, Wh, b)
  reserve, result = forwardTrain(desc(m), data(x), data(h))
  result, function (Δ)
    y, ho = result
    dy, dho = Δ
    h_ = hBatch(x, data(h))
    dx, dh = backwardData(descs[m], y, dy, dho, h_, reserve)
    (dWi, dWh), db = backwardWeights(descs[m], data(x), h_, y, reserve)
    nobacksies(:RNN, (dx, unbroadcast(h, dh), transpose(dWi), transpose(dWh), db))
  end
end

@grad function (m::CuLSTM)(x, h, c, Wi, Wh, b)
  reserve, result = forwardTrain(desc(m), data.((x, h, c))...)
  result, function (Δ)
    y, ho = result
    dy, dho, dco = Δ
    h_ = hBatch(x, data(h))
    c_ = hBatch(x, data(c))
    dx, dh, dc = backwardData(descs[m], y, dy, dho, dco, h_, c_, reserve)
    (dWi, dWh), db = backwardWeights(descs[m], data(x), h_, y, reserve)
    nobacksies(:RNN,
      (dx, unbroadcast(h, dh), unbroadcast(c, dc),
       transpose(dWi), transpose(dWh), db))
  end
end