diff --git a/src/cuda/cudnn.jl b/src/cuda/cudnn.jl
index 55fcc3e5..5ecb8cf0 100644
--- a/src/cuda/cudnn.jl
+++ b/src/cuda/cudnn.jl
@@ -147,3 +147,87 @@ function forwardInference(rnn::RNNDesc{T}, x, h, c = nothing) where T
     return y, hout, cout
   end
 end
+
+# Interface
+#
+# Glue that lets Flux's recurrent cells run through the CUDNN descriptors above.
+
+import ..Flux: Flux, relu
+import ..Flux.Tracker: TrackedArray
+using CUDAnative
+using CuArrays: @cuindex, cudims
+
+# Fill `dst` on the GPU so that `dst[I...] == src[reverse(I)...]` for every index `I`
+# of `dst` (a transpose for matrices, a plain copy for vectors). Returns `dst`.
+# NOTE(review): sizes are not checked; `src` is assumed to have the reversed shape.
+function copy_transpose!(dst::CuArray, src::CuArray)
+  function kernel(dst, src)
+    I = @cuindex dst
+    dst[I...] = src[reverse(I)...]
+    return
+  end
+  blk, thr = cudims(dst)
+  @cuda (blk, thr) kernel(dst, src)
+  return dst
+end
+
+# A GPU parameter array: either a bare `CuArray` or a `TrackedArray` wrapping one.
+CuParam{T,N} = Union{CuArray{T,N},TrackedArray{T,N,CuArray{T,N}}}
+# Flux cells whose matrices and vectors are `CuParam`s. `CuRNN` additionally
+# requires the activation to be `tanh` or `relu`.
+CuRNN{T} = Flux.RNNCell{<:Union{typeof(tanh),typeof(relu)},<:CuParam{T,2},<:CuParam{T,1}}
+CuGRU{T} = Flux.GRUCell{<:CuParam{T,2},<:CuParam{T,1}}
+CuLSTM{T} = Flux.LSTMCell{<:CuParam{T,2},<:CuParam{T,1}}
+CuRNNs{T} = Union{CuRNN{T},CuGRU{T},CuLSTM{T}}
+
+# Copy the cell's `Wi`, `Wh` and `b` (unwrapped with `Flux.data`) into the
+# descriptor's two weight buffers and its bias buffer via `copy_transpose!`.
+function copyparams!(m::CuRNNs, d::RNNDesc)
+  Wi, Wh = d.weights
+  copy_transpose!(Wi, Flux.data(m.Wi))
+  copy_transpose!(Wh, Flux.data(m.Wh))
+  copy_transpose!(d.bias, Flux.data(m.b))
+  return
+end
+
+# Mode constant passed to `RNNDesc` for each supported cell type.
+rnnmode(m::CuRNN) = m.σ == tanh ? RNN_TANH : RNN_RELU
+rnnmode(::CuGRU) = GRU
+rnnmode(::CuLSTM) = LSTM
+
+# Build a descriptor for cell `m`: hidden size is `length(m.h)`, input size is
+# `size(m.Wi, 2)`. Parameters are not copied here; `copyparams!` does that.
+function RNNDesc(m::CuRNNs{T}) where {T}
+  h, i = length(m.h), size(m.Wi, 2)
+  return RNNDesc{T}(rnnmode(m), i, h)
+end
+
+# Cache of descriptors keyed weakly by cell.
+# NOTE(review): presumably this is why the cell structs become `mutable` in this
+# change (weak keys need mutable objects) -- confirm.
+const descs = WeakKeyDict()
+
+# Return the descriptor for `rnn`, creating and caching it on first use. The cell's
+# parameters are re-copied into the descriptor on every call.
+function desc(rnn)
+  d = haskey(descs, rnn) ? descs[rnn] : (descs[rnn] = RNNDesc(rnn))
+  copyparams!(rnn, d)
+  return d
+end
+
+# Cell calls take `(state, input)` and return `(new state, output)`. State and input
+# are unwrapped with `Flux.data` before being handed to `forwardInference`.
+function (m::CuRNN{T})(h::CuParam{T}, x::CuParam{T}) where T <: Union{Float32,Float64}
+  y, hout = forwardInference(desc(m), Flux.data(x), Flux.data(h))
+  return hout, y
+end
+
+function (m::CuGRU{T})(h::CuParam{T}, x::CuParam{T}) where T <: Union{Float32,Float64}
+  y, hout = forwardInference(desc(m), Flux.data(x), Flux.data(h))
+  return hout, y
+end
+
+function (m::CuLSTM{T})(h::NTuple{2,CuParam{T}}, x::CuParam{T}) where T <: Union{Float32,Float64}
+  y, hout, cout = forwardInference(desc(m), Flux.data(x), Flux.data.(h)...)
+  return (hout, cout), y
+end
diff --git a/src/layers/recurrent.jl b/src/layers/recurrent.jl
index 50adfc86..992b706c 100644
--- a/src/layers/recurrent.jl
+++ b/src/layers/recurrent.jl
@@ -73,7 +73,7 @@ flip(f, xs) = reverse(f.(reverse(xs)))
 
 # Vanilla RNN
 
-struct RNNCell{F,A,V}
+mutable struct RNNCell{F,A,V}
   σ::F
   Wi::A
   Wh::A
@@ -112,7 +112,7 @@ RNN(a...; ka...) = Recur(RNNCell(a...; ka...))
 
 # LSTM
 
-struct LSTMCell{A,V}
+mutable struct LSTMCell{A,V}
   Wi::A
   Wh::A
   b::V
@@ -161,7 +161,7 @@ LSTM(a...; ka...) = Recur(LSTMCell(a...; ka...))
 
 # GRU
 
-struct GRUCell{A,V}
+mutable struct GRUCell{A,V}
   Wi::A
   Wh::A
   b::V