# Flux.jl/src/cuda/curnn.jl


import ..Flux: Flux, relu
using CuArrays.CUDAnative
using CuArrays: @cuindex, cudims
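# Aliases for Flux's recurrent cells whose parameters live on the GPU as CuArrays.
# The CuDNN-backed methods and adjoints below dispatch on these types, so they only
# take effect once a model has been moved to the GPU.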
CuRNN{T} = Flux.RNNCell{<:Union{typeof(tanh),typeof(relu)},<:CuArray{T,2},<:CuArray{T,1}}
CuGRU{T} = Flux.GRUCell{<:CuArray{T,2},<:CuArray{T,1}}
CuLSTM{T} = Flux.LSTMCell{<:CuArray{T,2},<:CuArray{T,1}}
CuRNNs{T} = Union{CuRNN{T},CuGRU{T},CuLSTM{T}}
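# Build a CuDNN RNN descriptor for a cell: the hidden and input sizes are read from
# the cell's weights, and the CuDNN mode is chosen from the cell type (and, for
# RNNCell, from its activation function).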
function CUDNN.RNNDesc(m::CuRNNs{T}) where T
  h, i = length(m.h), size(m.Wi, 2)
  mode = m isa CuRNN ?
    (m.σ == tanh ? CUDNN.CUDNN_RNN_TANH : CUDNN.CUDNN_RNN_RELU) :
    m isa CuGRU ? CUDNN.CUDNN_GRU : CUDNN.CUDNN_LSTM
  r = CUDNN.RNNDesc{T}(mode, i, h)
  return r
end
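# Descriptors are cached per cell in a WeakKeyDict, so they can be collected together
# with their cells. `desc` copies the current weights into the descriptor on every
# call, keeping the CuDNN copy in sync with the Flux parameters.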
const descs = WeakKeyDict()
function desc(rnn)
  d = haskey(descs, rnn) ? descs[rnn] : (descs[rnn] = CUDNN.RNNDesc(rnn))
  CUDNN.setweights!(d, rnn.Wi, rnn.Wh, rnn.b)
  return d
end
import Zygote
using Zygote: @adjoint
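# Forward passes used outside of differentiation; Flux's `Recur` calls the wrapped
# cell as `cell(state, x)` and expects `(new_state, output)` back, which is why the
# CuDNN results are returned in that order. The training-time versions, which keep
# a CuDNN reserve buffer for the backward pass, are defined via @adjoint further down.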
function (m::CuRNN{T})(h::CuArray{T}, x::CuArray{T}) where T <: Union{Float32,Float64}
  y, h = CUDNN.forward(desc(m), x, h)
  return h, y
end
function (m::CuGRU{T})(h::CuArray{T}, x::CuArray{T}) where T <: Union{Float32,Float64}
  y, h = CUDNN.forward(desc(m), x, h)
  return h, y
end
function (m::CuLSTM{T})(h::NTuple{2,CuArray{T}}, x::CuArray{T}) where T <: Union{Float32,Float64}
  y, h, c = CUDNN.forward(desc(m), x, h[1], h[2])
  return (h, c), y
end
(m::CuRNN{T})(h::CuArray{T}, x) where T <: Union{Float32,Float64} = m(h, CuArray{T}(x))
(m::CuGRU{T})(h::CuArray{T}, x) where T <: Union{Float32,Float64} = m(h, CuArray{T}(x))
(m::CuLSTM{T})(h::NTuple{2,CuArray{T}}, x) where T <: Union{Float32,Float64} = m(h, CuArray{T}(x))
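# Usage sketch (illustrative only, not part of the library): once a recurrent layer's
# parameters are CuArrays, the wrapped cell matches the aliases above and its calls
# are routed through CuDNN.
#
#   using Flux, CuArrays
#   m = RNN(10, 5) |> gpu      # Recur around an RNNCell with CuArray weights
#   x = cu(rand(Float32, 10))
#   y = m(x)                   # Recur invokes the CuRNN method defined above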
trim(x, Δ) = reshape(Δ, ntuple(i -> size(Δ, i), Val(ndims(x))))
unbroadcast(x::AbstractArray, Δ) =
  size(x) == size(Δ) ? Δ :
  length(x) == length(Δ) ? trim(x, Δ) :
    trim(x, sum(Δ, dims = ntuple(i -> size(x, i) == 1 ? i : ndims(Δ)+1, Val(ndims(Δ)))))
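# `trim` and `unbroadcast` reduce a gradient Δ back to the shape of the array it
# belongs to, summing over any dimensions along which that array was broadcast
# (e.g. a length-n hidden state paired with an n×batch gradient is summed over the
# batch dimension).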
coerce_cuda(x::Union{CuArray,Nothing}) = x
coerce_cuda(x::Tuple) = coerce_cuda.(x)
coerce_cuda(x) = x .+ CuArrays.fill(0)
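# `coerce_cuda` normalises incoming gradients: CuArrays and `nothing` pass through,
# tuples are handled elementwise, and anything else (e.g. a gradient that arrives as
# a plain or lazily-filled host array) is materialised as a CuArray by broadcasting
# against `CuArrays.fill(0)`.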
function struct_grad!(cx::Zygote.Context, x, x̄)
  for f in fieldnames(typeof(x))
    Zygote.accum_param(cx, getfield(x, f), getfield(x̄, f))
  end
  dx = Zygote.grad_mut(cx, x)
  dx[] = Zygote.accum(dx[], x̄)
  return dx
end
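# `struct_grad!` accumulates a field-by-field gradient x̄ for a struct x: each field's
# gradient is fed to Zygote's implicit-parameter accumulation, and the whole named
# tuple is merged into the struct's cached gradient, whose Ref is returned.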
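# Training-time adjoints. The forward pass goes through CUDNN.forwardTrain, which also
# returns a `reserve` buffer used by the backward calls. The pullback coerces incoming
# gradients to CuArrays, runs CuDNN's backwardData/backwardWeights, and accumulates the
# weight gradients onto the cell via `struct_grad!`, transposed into the layout of the
# `Wi`/`Wh` fields.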
for RNN in (CuRNN, CuGRU)
  @eval @adjoint function (m::$RNN{T})(h::CuArray{T}, x::CuArray{T}) where T <: Union{Float32,Float64}
    reserve, (y, ho) = CUDNN.forwardTrain(desc(m), x, h)
    (ho, y), function (Δ)
      dho, dy = coerce_cuda(Δ)
      h_ = CUDNN.hBatch(x, h)
      dx, dh = CUDNN.backwardData(descs[m], y, dy, dho, h_, reserve)
      (dWi, dWh), db = CUDNN.backwardWeights(descs[m], x, h_, y, reserve)
      dm = struct_grad!(__context__, m, (σ=nothing,Wi=transpose(dWi),Wh=transpose(dWh),b=db,h=nothing))
      (dm, unbroadcast(h, dh), dx)
    end
  end
end
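# LSTM adjoint: as above, but the state is the (h, c) pair; `dhc === nothing` covers
# pullbacks where no gradient flows back for the state.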
@adjoint function (m::CuLSTM)((h, c)::Tuple{CuArray{T},CuArray{T}}, x::CuArray{T}) where T <: Union{Float32,Float64}
  reserve, (y, ho, co) = CUDNN.forwardTrain(desc(m), x, h, c)
  ((ho, co), y), function (Δ)
    dhc, dy = coerce_cuda(Δ)
    dho, dco = dhc === nothing ? (nothing, nothing) : dhc
    h_ = CUDNN.hBatch(x, h)
    c_ = CUDNN.hBatch(x, c)
    dx, dh, dc = CUDNN.backwardData(descs[m], y, dy, dho, dco, h_, c_, reserve)
    (dWi, dWh), db = CUDNN.backwardWeights(descs[m], x, h_, y, reserve)
    dm = struct_grad!(__context__, m, (Wi=transpose(dWi),Wh=transpose(dWh),b=db,h=nothing,c=nothing))
    (dm, (unbroadcast(h, dh), unbroadcast(c, dc)), dx)
  end
end