fixed params getting zero
This commit is contained in:
parent
73385b5dbd
commit
4d1a6c305b
@ -87,11 +87,14 @@ function RNNDesc{T}(mode::Int, input::Int, hidden::Int; layers = 1) where T
|
||||
libcudnn_handle[],d[],hidden,layers,dropoutDesc,inputMode,direction,mode,algo,cudnnDataType(T))
|
||||
|
||||
w = cuzeros(T, rnnParamSize(T, d[], input))
|
||||
(wx, wh), bias = params(w, input, hidden, ngates(mode))
|
||||
w_ = vcat(wx[:], wh[:], bias)
|
||||
w[1:length(w_)] .= w_
|
||||
# TODO: avoid reserve allocation here
|
||||
rd = RNNDesc{T}(mode, input, hidden, w, params(w, input, hidden, ngates(mode))..., d[])
|
||||
rd = RNNDesc{T}(mode, input, hidden, w, (wx, wh), bias, d[])
|
||||
finalizer(rd) do x
|
||||
@check ccall((:cudnnDestroyRNNDescriptor,libcudnn),cudnnStatus_t,(Ptr{Nothing},),x)
|
||||
end
|
||||
end
|
||||
return rd
|
||||
end
|
||||
|
||||
@ -270,6 +273,9 @@ function copyparams!(m::CuRNNs, d::RNNDesc)
|
||||
copy_transpose!(Wi, Flux.data(m.Wi))
|
||||
copy_transpose!(Wh, Flux.data(m.Wh))
|
||||
copy_transpose!(d.bias, Flux.data(m.b))
|
||||
|
||||
w_ = vcat(Wi[:], Wh[:], d.bias[:])
|
||||
d.params[1:length(w_)] .= w_
|
||||
return
|
||||
end
|
||||
|
||||
@ -279,6 +285,9 @@ function RNNDesc(m::CuRNNs{T}) where T
|
||||
(m.σ == tanh ? RNN_TANH : RNN_RELU) :
|
||||
m isa CuGRU ? GRU : LSTM
|
||||
r = RNNDesc{T}(mode, i, h)
|
||||
#w_ = vcat(m.Wi[:], m.Wh[:], m.b)
|
||||
#r.params[1:length(w_)] .= w_
|
||||
|
||||
return r
|
||||
end
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user