Query the worksize.
This commit is contained in:
parent
04fce70019
commit
1e7ff4f65d
@ -158,11 +158,12 @@ function backwardData(rnn::RNNDesc{T}, y, dy_, dho, dco, h, c, reserve) where T
|
||||
dx = y isa AbstractVector ? similar(dy, rnn.input) : similar(dy, rnn.input, size(dy, 2))
|
||||
dh = similar(h)
|
||||
dc = c == nothing ? nothing : similar(c)
|
||||
workspace = getworkspace(rnn, 1, yd)
|
||||
CUDNN.cudnnRNNBackwardData(handle(), rnn, 1,
|
||||
yd, y, yd, dy, hDesc(dho)..., hDesc(dco)...,
|
||||
FilterDesc(T, (1, 1, length(rnn.params))), rnn.params,
|
||||
hDesc(h)..., hDesc(c)..., xDesc(dx), dx, hDesc(dh)..., hDesc(dc)...,
|
||||
workspace[], length(workspace[]), reserve, length(reserve))
|
||||
workspace, length(workspace), reserve, length(reserve))
|
||||
return c == nothing ? (dx, dh) : (dx, dh, dc)
|
||||
end
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user