2018-06-17 06:17:49 +00:00
|
|
|
|
import ..Flux: Flux, relu
|
2019-08-27 07:33:15 +00:00
|
|
|
|
using CuArrays.CUDAnative
|
|
|
|
|
using CuArrays: @cuindex, cudims
|
2018-06-17 06:17:49 +00:00
|
|
|
|
|
2019-03-08 12:06:09 +00:00
|
|
|
|
# Type aliases used for dispatch in this file. Each one matches a Flux recurrent cell
# whose array type parameters are a 2-d and a 1-d `CuArray` with element type `T`.

# Plain RNN cell; only cells whose activation is `tanh` or `relu` are matched.
const CuRNN{T} = Flux.RNNCell{<:Union{typeof(tanh),typeof(relu)},<:CuArray{T,2},<:CuArray{T,1}}

# GRU cell.
const CuGRU{T} = Flux.GRUCell{<:CuArray{T,2},<:CuArray{T,1}}

# LSTM cell.
const CuLSTM{T} = Flux.LSTMCell{<:CuArray{T,2},<:CuArray{T,1}}

# Any of the three cell kinds above, sharing the same element type `T`.
const CuRNNs{T} = Union{CuRNN{T},CuGRU{T},CuLSTM{T}}
|
|
|
|
|
|
2019-08-30 06:39:51 +00:00
|
|
|
|
# Construct a `CUDNN.RNNDesc{T}` for one of the CUDA-backed recurrent cells.
#
# The two sizes handed to the descriptor are `size(m.Wi, 2)` and `length(m.h)`
# (presumably the input and hidden sizes -- confirm against `CUDNN.RNNDesc`).
# The CUDNN mode follows the cell kind; for a plain RNN cell it additionally
# depends on whether the activation `m.σ` is `tanh`.
function CUDNN.RNNDesc(m::CuRNNs{T}) where T
    nhidden = length(m.h)
    ninput = size(m.Wi, 2)
    if m isa CuRNN
        mode = m.σ == tanh ? CUDNN.CUDNN_RNN_TANH : CUDNN.CUDNN_RNN_RELU
    elseif m isa CuGRU
        mode = CUDNN.CUDNN_GRU
    else
        mode = CUDNN.CUDNN_LSTM
    end
    return CUDNN.RNNDesc{T}(mode, ninput, nhidden)
end
|
|
|
|
|
|
|
|
|
|
# Cache of CUDNN RNN descriptors keyed by recurrent cell; entries are added by `desc`.
# The keys are held weakly, so an entry here does not keep its cell alive.
const descs = WeakKeyDict()
|
|
|
|
|
|
|
|
|
|
# Return the CUDNN descriptor associated with `rnn`.
#
# The descriptor is looked up in `descs`; if none is cached yet, one is built with
# `CUDNN.RNNDesc(rnn)` and stored. On every call the cell's current `Wi`, `Wh` and
# `b` are passed to `CUDNN.setweights!` before the descriptor is returned.
function desc(rnn)
    if haskey(descs, rnn)
        cached = descs[rnn]
    else
        cached = CUDNN.RNNDesc(rnn)
        descs[rnn] = cached
    end
    CUDNN.setweights!(cached, rnn.Wi, rnn.Wh, rnn.b)
    return cached
end
|
|
|
|
|
|
2019-09-17 14:41:42 +00:00
|
|
|
|
import Zygote
|
|
|
|
|
using Zygote: @adjoint
|
2018-06-17 06:17:49 +00:00
|
|
|
|
|
2019-03-08 12:06:09 +00:00
|
|
|
|
# Forward step of a plain RNN cell through CUDNN when the state and the input are
# both `CuArray`s with the cell's element type.
#
# `CUDNN.forward` yields its two results in the order `(y, h′)`; they are returned
# swapped, as `(h′, y)`.
function (m::CuRNN{T})(h::CuArray{T}, x::CuArray{T}) where T <: Union{Float32,Float64}
    out, hnext = CUDNN.forward(desc(m), x, h)
    return hnext, out
end
|
|
|
|
|
|
2019-03-08 12:06:09 +00:00
|
|
|
|
# Forward step of a GRU cell through CUDNN when the state and the input are both
# `CuArray`s with the cell's element type.
#
# `CUDNN.forward` yields its two results in the order `(y, h′)`; they are returned
# swapped, as `(h′, y)`.
function (m::CuGRU{T})(h::CuArray{T}, x::CuArray{T}) where T <: Union{Float32,Float64}
    out, hnext = CUDNN.forward(desc(m), x, h)
    return hnext, out
end
|
|
|
|
|
|
2019-03-08 12:06:09 +00:00
|
|
|
|
# Forward step of an LSTM cell through CUDNN. The state `h` is a 2-tuple of
# `CuArray`s whose entries are passed to CUDNN as separate arguments.
#
# `CUDNN.forward` yields three results `(y, h′, c′)`; they are regrouped and
# returned as `((h′, c′), y)`.
function (m::CuLSTM{T})(h::NTuple{2,CuArray{T}}, x::CuArray{T}) where T <: Union{Float32,Float64}
    hprev, cprev = h[1], h[2]
    out, hnext, cnext = CUDNN.forward(desc(m), x, hprev, cprev)
    return (hnext, cnext), out
end
|
|
|
|
|
|
2019-03-08 12:06:09 +00:00
|
|
|
|
# Fallbacks for an input `x` that is not already a `CuArray{T}`: the input is
# converted with `CuArray{T}(x)` and the cell is called again with the result.

function (m::CuRNN{T})(h::CuArray{T}, x) where T <: Union{Float32,Float64}
    return m(h, CuArray{T}(x))
end

function (m::CuGRU{T})(h::CuArray{T}, x) where T <: Union{Float32,Float64}
    return m(h, CuArray{T}(x))
end

function (m::CuLSTM{T})(h::NTuple{2,CuArray{T}}, x) where T <: Union{Float32,Float64}
    return m(h, CuArray{T}(x))
end
|
2018-06-17 06:17:49 +00:00
|
|
|
|
|
2019-08-19 15:56:48 +00:00
|
|
|
|
# Reshape `Δ` so that it has exactly `ndims(x)` dimensions, keeping the sizes of
# `Δ`'s leading dimensions.
function trim(x, Δ)
    leading = ntuple(d -> size(Δ, d), Val(ndims(x)))
    return reshape(Δ, leading)
end

# Reduce `Δ` to the shape of `x`.
#
# * Equal sizes: `Δ` is returned as is.
# * Equal lengths but different sizes: `Δ` is only reshaped (see `trim`).
# * Otherwise: `Δ` is summed over every dimension in which `x` has size 1, then
#   reshaped. Dimensions where `x` is not of size 1 are mapped to `ndims(Δ) + 1`,
#   a dimension `Δ` does not have, so they are left unreduced.
function unbroadcast(x::AbstractArray, Δ)
    size(x) == size(Δ) && return Δ
    length(x) == length(Δ) && return trim(x, Δ)
    reduced = ntuple(d -> size(x, d) == 1 ? d : ndims(Δ) + 1, Val(ndims(Δ)))
    return trim(x, sum(Δ, dims = reduced))
end
|
|
|
|
|
|
2019-09-17 14:21:03 +00:00
|
|
|
|
# `coerce_cuda` is applied to the incoming gradient in the adjoints of this file
# before it is handed to CUDNN.

# `CuArray`s and `nothing` are passed through unchanged.
function coerce_cuda(x::Union{CuArray,Nothing})
    return x
end

# Tuples are processed element by element (nested tuples recurse).
function coerce_cuda(x::Tuple)
    return map(coerce_cuda, x)
end

# Anything else is broadcast-added to the zero-dimensional `CuArrays.fill(0)`.
# NOTE(review): presumably this moves `x` onto the GPU without changing its
# values -- confirm.
function coerce_cuda(x)
    return x .+ CuArrays.fill(0)
end
|
|
|
|
|
|
2019-09-17 14:41:42 +00:00
|
|
|
|
# Record the gradient `x̄` of the struct `x` in the Zygote context `cx`.
#
# Every field of `x` is paired with the field of the same name in `x̄` and passed to
# `Zygote.accum_param`. Afterwards `x̄` as a whole is accumulated into the reference
# obtained from `Zygote.grad_mut(cx, x)`, and that reference is returned.
#
# `x̄` must provide every field name that `typeof(x)` has.
function struct_grad!(cx::Zygote.Context, x, x̄)
    foreach(fieldnames(typeof(x))) do name
        Zygote.accum_param(cx, getfield(x, name), getfield(x̄, name))
    end
    acc = Zygote.grad_mut(cx, x)
    acc[] = Zygote.accum(acc[], x̄)
    return acc
end
|
|
|
|
|
|
2019-08-19 13:39:09 +00:00
|
|
|
|
# Zygote adjoints for the CUDNN-backed RNN and GRU calls. The same definition is
# generated once per cell type.
for Cell in (CuRNN, CuGRU)
    @eval @adjoint function (m::$Cell{T})(h::CuArray{T}, x::CuArray{T}) where T <: Union{Float32,Float64}
        # `forwardTrain` also returns a reserve buffer that the backward calls need.
        workspace, (out, hnext) = CUDNN.forwardTrain(desc(m), x, h)
        (hnext, out), function (Δ)
            grad_hnext, grad_out = coerce_cuda(Δ)
            hbatch = CUDNN.hBatch(x, h)
            # `desc(m)` above has already stored the descriptor in `descs`.
            grad_x, grad_h = CUDNN.backwardData(descs[m], out, grad_out, grad_hnext, hbatch, workspace)
            (grad_Wi, grad_Wh), grad_b = CUDNN.backwardWeights(descs[m], x, hbatch, out, workspace)
            # The weight gradients are transposed before being stored on the cell.
            grad_m = struct_grad!(__context__, m, (σ=nothing,Wi=transpose(grad_Wi),Wh=transpose(grad_Wh),b=grad_b,h=nothing))
            (grad_m, unbroadcast(h, grad_h), grad_x)
        end
    end
end
|
|
|
|
|
|
2019-08-19 15:56:48 +00:00
|
|
|
|
# Zygote adjoint for the CUDNN-backed LSTM call.
#
# The cell is constrained to `CuLSTM{T}` so that this adjoint applies to exactly the
# calls handled by the CUDNN forward method, i.e. the cell and the arrays share the
# element type `T`. Calls with mismatched element types are left to Zygote's generic
# differentiation, as they are left to the generic forward method.
#
# The pullback receives `Δ = ((dh, dc), dy)`; the state part may be `nothing`.
@adjoint function (m::CuLSTM{T})((h, c)::Tuple{CuArray{T},CuArray{T}}, x::CuArray{T}) where T <: Union{Float32,Float64}
    # `forwardTrain` also returns a reserve buffer that the backward calls need.
    reserve, (y, ho, co) = CUDNN.forwardTrain(desc(m), x, h, c)
    ((ho, co), y), function (Δ)
        dhc, dy = coerce_cuda(Δ)
        # No gradient for the state at all: pass `nothing` for both parts.
        dho, dco = dhc === nothing ? (nothing, nothing) : dhc
        h_ = CUDNN.hBatch(x, h)
        c_ = CUDNN.hBatch(x, c)
        # `desc(m)` above has already stored the descriptor in `descs`.
        dx, dh, dc = CUDNN.backwardData(descs[m], y, dy, dho, dco, h_, c_, reserve)
        (dWi, dWh), db = CUDNN.backwardWeights(descs[m], x, h_, y, reserve)
        # The weight gradients are transposed before being stored on the cell.
        dm = struct_grad!(__context__, m, (Wi=transpose(dWi),Wh=transpose(dWh),b=db,h=nothing,c=nothing))
        (dm, (unbroadcast(h, dh), unbroadcast(c, dc)), dx)
    end
end
|