hook up interface
This commit is contained in:
parent
b1c5786012
commit
9a6fcf057b
@ -147,3 +147,67 @@ function forwardInference(rnn::RNNDesc{T}, x, h, c = nothing) where T
|
||||
return y, hout, cout
|
||||
end
|
||||
end
|
||||
|
||||
# Interface
|
||||
|
||||
import ..Flux: Flux, relu
|
||||
import ..Flux.Tracker: TrackedArray
|
||||
using CUDAnative
|
||||
using CuArrays: @cuindex, cudims
|
||||
|
||||
function copy_transpose!(dst::CuArray, src::CuArray)
|
||||
function kernel(dst, src)
|
||||
I = @cuindex dst
|
||||
dst[I...] = src[reverse(I)...]
|
||||
return
|
||||
end
|
||||
blk, thr = cudims(dst)
|
||||
@cuda (blk, thr) kernel(dst, src)
|
||||
return dst
|
||||
end
|
||||
|
||||
CuParam{T,N} = Union{CuArray{T,N},TrackedArray{T,N,CuArray{T,N}}}
|
||||
CuRNN{T} = Flux.RNNCell{<:Union{typeof(tanh),typeof(relu)},<:CuParam{T,2},<:CuParam{T,1}}
|
||||
CuGRU{T} = Flux.GRUCell{<:CuParam{T,2},<:CuParam{T,1}}
|
||||
CuLSTM{T} = Flux.LSTMCell{<:CuParam{T,2},<:CuParam{T,1}}
|
||||
CuRNNs{T} = Union{CuRNN{T},CuGRU{T},CuLSTM{T}}
|
||||
|
||||
function copyparams!(m::CuRNNs, d::RNNDesc)
|
||||
Wi, Wh = d.weights
|
||||
copy_transpose!(Wi, Flux.data(m.Wi))
|
||||
copy_transpose!(Wh, Flux.data(m.Wh))
|
||||
copy_transpose!(d.bias, Flux.data(m.b))
|
||||
return
|
||||
end
|
||||
|
||||
function RNNDesc(m::CuRNNs{T}) where {T}
|
||||
h, i = length(m.h), size(m.Wi, 2)
|
||||
mode = m isa CuRNN ?
|
||||
(m.σ == tanh ? RNN_TANH : RNN_RELU) :
|
||||
m isa CuGRU ? GRU : LSTM
|
||||
r = RNNDesc{T}(mode, i, h)
|
||||
return r
|
||||
end
|
||||
|
||||
const descs = WeakKeyDict()
|
||||
|
||||
function desc(rnn)
|
||||
d = haskey(descs, rnn) ? descs[rnn] : (descs[rnn] = RNNDesc(rnn))
|
||||
copyparams!(rnn, d)
|
||||
return d
|
||||
end
|
||||
|
||||
function (m::CuRNN{T})(h::CuParam{T}, x::CuParam{T}) where T <: Union{Float32,Float64}
|
||||
y, h = forwardInference(desc(m), Flux.data(x), Flux.data(h))
|
||||
return h, y
|
||||
end
|
||||
|
||||
function (m::CuGRU{T})(h::CuParam{T}, x::CuParam{T}) where T <: Union{Float32,Float64}
|
||||
y, h = forwardInference(desc(m), Flux.data(x), Flux.data(h))
|
||||
return h, y
|
||||
end
|
||||
|
||||
function (m::CuLSTM{T})(h::NTuple{2,CuParam{T}}, x::CuParam{T}) where T <: Union{Float32,Float64}
|
||||
y, h, c = forwardInference(desc(m), Flux.data(x), Flux.data.(h)...)
|
||||
return (h, c), y
|
||||
end
|
||||
|
@ -73,7 +73,7 @@ flip(f, xs) = reverse(f.(reverse(xs)))
|
||||
|
||||
# Vanilla RNN
|
||||
|
||||
struct RNNCell{F,A,V}
|
||||
mutable struct RNNCell{F,A,V}
|
||||
σ::F
|
||||
Wi::A
|
||||
Wh::A
|
||||
@ -112,7 +112,7 @@ RNN(a...; ka...) = Recur(RNNCell(a...; ka...))
|
||||
|
||||
# LSTM
|
||||
|
||||
struct LSTMCell{A,V}
|
||||
mutable struct LSTMCell{A,V}
|
||||
Wi::A
|
||||
Wh::A
|
||||
b::V
|
||||
@ -161,7 +161,7 @@ LSTM(a...; ka...) = Recur(LSTMCell(a...; ka...))
|
||||
|
||||
# GRU
|
||||
|
||||
struct GRUCell{A,V}
|
||||
mutable struct GRUCell{A,V}
|
||||
Wi::A
|
||||
Wh::A
|
||||
b::V
|
||||
|
Loading…
Reference in New Issue
Block a user