diff --git a/src/cuda/cudnn.jl b/src/cuda/cudnn.jl
index 61609b0d..35551d0f 100644
--- a/src/cuda/cudnn.jl
+++ b/src/cuda/cudnn.jl
@@ -87,11 +87,14 @@ function RNNDesc{T}(mode::Int, input::Int, hidden::Int; layers = 1) where T
     libcudnn_handle[],d[],hidden,layers,dropoutDesc,inputMode,direction,mode,algo,cudnnDataType(T))
 
   w = cuzeros(T, rnnParamSize(T, d[], input))
+  (wx, wh), bias = params(w, input, hidden, ngates(mode))
+  w_ = vcat(wx[:], wh[:], bias)
+  w[1:length(w_)] .= w_
   # TODO: avoid reserve allocation here
-  rd = RNNDesc{T}(mode, input, hidden, w, params(w, input, hidden, ngates(mode))..., d[])
+  rd = RNNDesc{T}(mode, input, hidden, w, (wx, wh), bias, d[])
   finalizer(rd) do x
     @check ccall((:cudnnDestroyRNNDescriptor,libcudnn),cudnnStatus_t,(Ptr{Nothing},),x)
-  end
+  end
   return rd
 end
 
@@ -270,6 +273,9 @@ function copyparams!(m::CuRNNs, d::RNNDesc)
   copy_transpose!(Wi, Flux.data(m.Wi))
   copy_transpose!(Wh, Flux.data(m.Wh))
   copy_transpose!(d.bias, Flux.data(m.b))
+
+  w_ = vcat(Wi[:], Wh[:], d.bias[:])
+  d.params[1:length(w_)] .= w_
   return
 end
 
@@ -279,6 +285,9 @@ function RNNDesc(m::CuRNNs{T}) where T
     (m.σ == tanh ? RNN_TANH : RNN_RELU) :
     m isa CuGRU ? GRU : LSTM
   r = RNNDesc{T}(mode, i, h)
+  #w_ = vcat(m.Wi[:], m.Wh[:], m.b)
+  #r.params[1:length(w_)] .= w_
+
   return r
 end
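
For reviewers, a minimal plain-Julia sketch of the packed layout this patch keeps in sync: the flat parameter vector is written as input weights, then hidden weights, then biases, which is the ordering the `vcat(wx[:], wh[:], bias)` lines above assume. The names and sizes below (`input`, `hidden`, `gates`, the random stand-in matrices) are illustrative assumptions, not what CUDNN's `rnnParamSize` reports; no CUDA is involved.

# The descriptor holds both a flat buffer (d.params / w) and structured
# views (Wi, Wh, bias). After mutating the views, the patch rewrites the
# head of the flat buffer so the two stay consistent.
input, hidden, gates = 4, 8, 1                 # illustrative sizes; 1 gate ~ vanilla RNN

Wi   = randn(Float32, gates*hidden, input)     # stand-in for the input weights
Wh   = randn(Float32, gates*hidden, hidden)    # stand-in for the hidden weights
bias = randn(Float32, gates*hidden)            # stand-in for the gate biases

w  = zeros(Float32, length(Wi) + length(Wh) + length(bias))  # packed buffer

# Mirrors `w_ = vcat(wx[:], wh[:], bias); w[1:length(w_)] .= w_` from the
# patch: flatten each block column-major and assign into the buffer's prefix
# (the real buffer may be larger than the concatenation, hence the 1:length slice).
w_ = vcat(Wi[:], Wh[:], bias)
w[1:length(w_)] .= w_
@assert w[1:length(Wi)] == Wi[:]               # Wi occupies the head of the buffer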