diff --git a/src/cuda/cudnn.jl b/src/cuda/cudnn.jl index e292ac1c..c9728a48 100644 --- a/src/cuda/cudnn.jl +++ b/src/cuda/cudnn.jl @@ -130,6 +130,11 @@ function cudnnRNNForward(::Type{T}, rnn, seqlen, xd, x, hd, h, cd, c, wd, w, yd, end end +hDesc(h::Void) = C_NULL, C_NULL +function hDesc(h::CuArray) + TensorDesc(eltype(h), (size(h, 1), size(h, 2), 1)), h +end + function forward(rnn::RNNDesc{T}, x::CuArray{T}, h::CuArray{T}, c = nothing; train = false) where T @assert size(x, 1) == rnn.input @assert size(h, 1) == rnn.hidden @@ -140,29 +145,20 @@ function forward(rnn::RNNDesc{T}, x::CuArray{T}, h::CuArray{T}, c = nothing; tra ydesc = [TensorDesc(T, (1, size(y, 1), size(y, 2)))] workspace = CuVector{UInt8}(rnnWorkspaceSize(rnn, seqLength, xdesc)) # TODO: reuse this reserve = train ? getreserve(rnn, seqLength, xdesc) : rnn.reserve - if c ≠ nothing - @assert size(c, 1) == rnn.hidden - @assert size(c, 2) == size(h, 2) - cptr = c - cdesc = TensorDesc(T, (size(c, 1), size(c, 2), 1)) - cout = similar(c) - coutdesc = TensorDesc(T, (size(cout, 1), size(cout, 2), 1)) - else - cptr = cdesc = cout = coutdesc = C_NULL - end + cy = c == nothing ? c : similar(c) cudnnRNNForward(T, rnn, seqLength, xdesc, x, - TensorDesc(T, (size(h, 1), size(h, 2), 1)), h, - cdesc, cptr, + hDesc(h)..., + hDesc(c)..., TensorDesc(T, (1, 1, length(rnn.params))), rnn.params, ydesc, y, C_NULL, C_NULL, # hout - coutdesc, cout, + hDesc(cy)..., workspace, reserve, train = train) if c == nothing return y, y else - return y, y, cout + return y, y, cy end end