Query the worksize.

This commit is contained in:
Tim Besard 2019-08-29 17:26:10 +02:00
parent 04fce70019
commit 1e7ff4f65d

View File

@ -158,11 +158,12 @@ function backwardData(rnn::RNNDesc{T}, y, dy_, dho, dco, h, c, reserve) where T
dx = y isa AbstractVector ? similar(dy, rnn.input) : similar(dy, rnn.input, size(dy, 2)) dx = y isa AbstractVector ? similar(dy, rnn.input) : similar(dy, rnn.input, size(dy, 2))
dh = similar(h) dh = similar(h)
dc = c == nothing ? nothing : similar(c) dc = c == nothing ? nothing : similar(c)
workspace = getworkspace(rnn, 1, yd)
CUDNN.cudnnRNNBackwardData(handle(), rnn, 1, CUDNN.cudnnRNNBackwardData(handle(), rnn, 1,
yd, y, yd, dy, hDesc(dho)..., hDesc(dco)..., yd, y, yd, dy, hDesc(dho)..., hDesc(dco)...,
FilterDesc(T, (1, 1, length(rnn.params))), rnn.params, FilterDesc(T, (1, 1, length(rnn.params))), rnn.params,
hDesc(h)..., hDesc(c)..., xDesc(dx), dx, hDesc(dh)..., hDesc(dc)..., hDesc(h)..., hDesc(c)..., xDesc(dx), dx, hDesc(dh)..., hDesc(dc)...,
workspace[], length(workspace[]), reserve, length(reserve)) workspace, length(workspace), reserve, length(reserve))
return c == nothing ? (dx, dh) : (dx, dh, dc) return c == nothing ? (dx, dh) : (dx, dh, dc)
end end