Query the worksize.

This commit is contained in:
Tim Besard 2019-08-29 17:26:10 +02:00
parent 04fce70019
commit 1e7ff4f65d

View File

@ -158,11 +158,12 @@ function backwardData(rnn::RNNDesc{T}, y, dy_, dho, dco, h, c, reserve) where T
dx = y isa AbstractVector ? similar(dy, rnn.input) : similar(dy, rnn.input, size(dy, 2))
dh = similar(h)
dc = c == nothing ? nothing : similar(c)
workspace = getworkspace(rnn, 1, yd)
CUDNN.cudnnRNNBackwardData(handle(), rnn, 1,
yd, y, yd, dy, hDesc(dho)..., hDesc(dco)...,
FilterDesc(T, (1, 1, length(rnn.params))), rnn.params,
hDesc(h)..., hDesc(c)..., xDesc(dx), dx, hDesc(dh)..., hDesc(dc)...,
workspace[], length(workspace[]), reserve, length(reserve))
workspace, length(workspace), reserve, length(reserve))
return c == nothing ? (dx, dh) : (dx, dh, dc)
end