Compare commits

...

1 Commits

Author SHA1 Message Date
Mike Innes 5e51f75f45 gc nuclear option 2019-09-10 13:46:09 +01:00
1 changed files with 33 additions and 18 deletions

View File

@ -4,6 +4,14 @@ using CuArrays.CUDNN: @check, cudnnStatus_t, cudnnTensorDescriptor_t,
import CuArrays.CUDAdrv: CuPtr, CU_NULL
macro nogc(ex)
quote
st = GC.enable(false)
$(esc(ex))
GC.enable(st)
end
end
using LinearAlgebra
const RNN_RELU = 0 # Stock RNN with ReLu activation
@ -153,15 +161,17 @@ function forward(rnn::RNNDesc{T}, x::CuArray{T}, h_::CuArray{T}, c_ = nothing, t
CuVector{UInt8}(undef, rnnTrainingReserveSize(rnn, seqLength, xdesc)) :
nothing
co = c == nothing ? c : similar(c)
cudnnRNNForward(rnn, seqLength,
xdesc, x,
hDesc(h)...,
hDesc(c)...,
FilterDesc(T, (1, 1, length(rnn.params))), rnn.params,
ydesc, y,
hDesc(ho)...,
hDesc(co)...,
workspace, reserve)
@nogc begin
cudnnRNNForward(rnn, seqLength,
xdesc, x,
hDesc(h)...,
hDesc(c)...,
FilterDesc(T, (1, 1, length(rnn.params))), rnn.params,
ydesc, y,
hDesc(ho)...,
hDesc(co)...,
workspace, reserve)
end
result = c == nothing ? (y, ho) : (y, ho, co)
return train == Val{true} ? (reserve, result) : result
end
@ -188,11 +198,14 @@ function backwardData(rnn::RNNDesc{T}, y, dy_, dho, dco, h, c, reserve) where T
dx = y isa AbstractVector ? similar(dy, rnn.input) : similar(dy, rnn.input, size(dy, 2))
dh = similar(h)
dc = c == nothing ? nothing : similar(c)
cudnnRNNBackwardData(rnn, 1,
yd, y, yd, dy, hDesc(dho)..., hDesc(dco)...,
FilterDesc(T, (1, 1, length(rnn.params))), rnn.params,
hDesc(h)..., hDesc(c)..., xDesc(dx), dx, hDesc(dh)..., hDesc(dc)...,
workspace[], reserve)
desc = FilterDesc(T, (1, 1, length(rnn.params)))
@nogc begin
cudnnRNNBackwardData(rnn, 1,
yd, y, yd, dy, hDesc(dho)..., hDesc(dco)...,
desc, rnn.params,
hDesc(h)..., hDesc(c)..., xDesc(dx), dx, hDesc(dh)..., hDesc(dc)...,
workspace[], reserve)
end
return c == nothing ? (dx, dh) : (dx, dh, dc)
end
@ -215,10 +228,12 @@ end
function backwardWeights(rnn::RNNDesc{T}, x, h, y, reserve) where T
dw = zero(rnn.params)
cudnnRNNBackwardWeights(rnn, 1,
xDesc(x), x, hDesc(h)..., xDesc(y), y,
FilterDesc(T, (1, 1, length(dw))), dw,
workspace[], reserve)
@nogc begin
cudnnRNNBackwardWeights(rnn, 1,
xDesc(x), x, hDesc(h)..., xDesc(y), y,
FilterDesc(T, (1, 1, length(dw))), dw,
workspace[], reserve)
end
return params(dw, rnn.input, rnn.hidden, ngates(rnn))
end