diff --git a/src/cuda/curnn.jl b/src/cuda/curnn.jl index c37d031c..bbd4e122 100644 --- a/src/cuda/curnn.jl +++ b/src/cuda/curnn.jl @@ -158,11 +158,12 @@ function backwardData(rnn::RNNDesc{T}, y, dy_, dho, dco, h, c, reserve) where T dx = y isa AbstractVector ? similar(dy, rnn.input) : similar(dy, rnn.input, size(dy, 2)) dh = similar(h) dc = c == nothing ? nothing : similar(c) + workspace = getworkspace(rnn, 1, yd) CUDNN.cudnnRNNBackwardData(handle(), rnn, 1, yd, y, yd, dy, hDesc(dho)..., hDesc(dco)..., FilterDesc(T, (1, 1, length(rnn.params))), rnn.params, hDesc(h)..., hDesc(c)..., xDesc(dx), dx, hDesc(dh)..., hDesc(dc)..., - workspace[], length(workspace[]), reserve, length(reserve)) + workspace, length(workspace), reserve, length(reserve)) return c == nothing ? (dx, dh) : (dx, dh, dc) end