gc nuclear option
This commit is contained in:
parent
3c1ac84676
commit
5e51f75f45
|
@ -4,6 +4,14 @@ using CuArrays.CUDNN: @check, cudnnStatus_t, cudnnTensorDescriptor_t,
|
|||
|
||||
import CuArrays.CUDAdrv: CuPtr, CU_NULL
|
||||
|
||||
macro nogc(ex)
|
||||
quote
|
||||
st = GC.enable(false)
|
||||
$(esc(ex))
|
||||
GC.enable(st)
|
||||
end
|
||||
end
|
||||
|
||||
using LinearAlgebra
|
||||
|
||||
const RNN_RELU = 0 # Stock RNN with ReLu activation
|
||||
|
@ -153,15 +161,17 @@ function forward(rnn::RNNDesc{T}, x::CuArray{T}, h_::CuArray{T}, c_ = nothing, t
|
|||
CuVector{UInt8}(undef, rnnTrainingReserveSize(rnn, seqLength, xdesc)) :
|
||||
nothing
|
||||
co = c == nothing ? c : similar(c)
|
||||
cudnnRNNForward(rnn, seqLength,
|
||||
xdesc, x,
|
||||
hDesc(h)...,
|
||||
hDesc(c)...,
|
||||
FilterDesc(T, (1, 1, length(rnn.params))), rnn.params,
|
||||
ydesc, y,
|
||||
hDesc(ho)...,
|
||||
hDesc(co)...,
|
||||
workspace, reserve)
|
||||
@nogc begin
|
||||
cudnnRNNForward(rnn, seqLength,
|
||||
xdesc, x,
|
||||
hDesc(h)...,
|
||||
hDesc(c)...,
|
||||
FilterDesc(T, (1, 1, length(rnn.params))), rnn.params,
|
||||
ydesc, y,
|
||||
hDesc(ho)...,
|
||||
hDesc(co)...,
|
||||
workspace, reserve)
|
||||
end
|
||||
result = c == nothing ? (y, ho) : (y, ho, co)
|
||||
return train == Val{true} ? (reserve, result) : result
|
||||
end
|
||||
|
@ -188,11 +198,14 @@ function backwardData(rnn::RNNDesc{T}, y, dy_, dho, dco, h, c, reserve) where T
|
|||
dx = y isa AbstractVector ? similar(dy, rnn.input) : similar(dy, rnn.input, size(dy, 2))
|
||||
dh = similar(h)
|
||||
dc = c == nothing ? nothing : similar(c)
|
||||
cudnnRNNBackwardData(rnn, 1,
|
||||
yd, y, yd, dy, hDesc(dho)..., hDesc(dco)...,
|
||||
FilterDesc(T, (1, 1, length(rnn.params))), rnn.params,
|
||||
hDesc(h)..., hDesc(c)..., xDesc(dx), dx, hDesc(dh)..., hDesc(dc)...,
|
||||
workspace[], reserve)
|
||||
desc = FilterDesc(T, (1, 1, length(rnn.params)))
|
||||
@nogc begin
|
||||
cudnnRNNBackwardData(rnn, 1,
|
||||
yd, y, yd, dy, hDesc(dho)..., hDesc(dco)...,
|
||||
desc, rnn.params,
|
||||
hDesc(h)..., hDesc(c)..., xDesc(dx), dx, hDesc(dh)..., hDesc(dc)...,
|
||||
workspace[], reserve)
|
||||
end
|
||||
return c == nothing ? (dx, dh) : (dx, dh, dc)
|
||||
end
|
||||
|
||||
|
@ -215,10 +228,12 @@ end
|
|||
|
||||
function backwardWeights(rnn::RNNDesc{T}, x, h, y, reserve) where T
|
||||
dw = zero(rnn.params)
|
||||
cudnnRNNBackwardWeights(rnn, 1,
|
||||
xDesc(x), x, hDesc(h)..., xDesc(y), y,
|
||||
FilterDesc(T, (1, 1, length(dw))), dw,
|
||||
workspace[], reserve)
|
||||
@nogc begin
|
||||
cudnnRNNBackwardWeights(rnn, 1,
|
||||
xDesc(x), x, hDesc(h)..., xDesc(y), y,
|
||||
FilterDesc(T, (1, 1, length(dw))), dw,
|
||||
workspace[], reserve)
|
||||
end
|
||||
return params(dw, rnn.input, rnn.hidden, ngates(rnn))
|
||||
end
|
||||
|
||||
|
|
Loading…
Reference in New Issue