Compare commits

...

1 Commits

Author SHA1 Message Date
Mike Innes
5e51f75f45 gc nuclear option 2019-09-10 13:46:09 +01:00

View File

@ -4,6 +4,14 @@ using CuArrays.CUDNN: @check, cudnnStatus_t, cudnnTensorDescriptor_t,
import CuArrays.CUDAdrv: CuPtr, CU_NULL import CuArrays.CUDAdrv: CuPtr, CU_NULL
macro nogc(ex)
quote
st = GC.enable(false)
$(esc(ex))
GC.enable(st)
end
end
using LinearAlgebra using LinearAlgebra
const RNN_RELU = 0 # Stock RNN with ReLu activation const RNN_RELU = 0 # Stock RNN with ReLu activation
@ -153,15 +161,17 @@ function forward(rnn::RNNDesc{T}, x::CuArray{T}, h_::CuArray{T}, c_ = nothing, t
CuVector{UInt8}(undef, rnnTrainingReserveSize(rnn, seqLength, xdesc)) : CuVector{UInt8}(undef, rnnTrainingReserveSize(rnn, seqLength, xdesc)) :
nothing nothing
co = c == nothing ? c : similar(c) co = c == nothing ? c : similar(c)
cudnnRNNForward(rnn, seqLength, @nogc begin
xdesc, x, cudnnRNNForward(rnn, seqLength,
hDesc(h)..., xdesc, x,
hDesc(c)..., hDesc(h)...,
FilterDesc(T, (1, 1, length(rnn.params))), rnn.params, hDesc(c)...,
ydesc, y, FilterDesc(T, (1, 1, length(rnn.params))), rnn.params,
hDesc(ho)..., ydesc, y,
hDesc(co)..., hDesc(ho)...,
workspace, reserve) hDesc(co)...,
workspace, reserve)
end
result = c == nothing ? (y, ho) : (y, ho, co) result = c == nothing ? (y, ho) : (y, ho, co)
return train == Val{true} ? (reserve, result) : result return train == Val{true} ? (reserve, result) : result
end end
@ -188,11 +198,14 @@ function backwardData(rnn::RNNDesc{T}, y, dy_, dho, dco, h, c, reserve) where T
dx = y isa AbstractVector ? similar(dy, rnn.input) : similar(dy, rnn.input, size(dy, 2)) dx = y isa AbstractVector ? similar(dy, rnn.input) : similar(dy, rnn.input, size(dy, 2))
dh = similar(h) dh = similar(h)
dc = c == nothing ? nothing : similar(c) dc = c == nothing ? nothing : similar(c)
cudnnRNNBackwardData(rnn, 1, desc = FilterDesc(T, (1, 1, length(rnn.params)))
yd, y, yd, dy, hDesc(dho)..., hDesc(dco)..., @nogc begin
FilterDesc(T, (1, 1, length(rnn.params))), rnn.params, cudnnRNNBackwardData(rnn, 1,
hDesc(h)..., hDesc(c)..., xDesc(dx), dx, hDesc(dh)..., hDesc(dc)..., yd, y, yd, dy, hDesc(dho)..., hDesc(dco)...,
workspace[], reserve) desc, rnn.params,
hDesc(h)..., hDesc(c)..., xDesc(dx), dx, hDesc(dh)..., hDesc(dc)...,
workspace[], reserve)
end
return c == nothing ? (dx, dh) : (dx, dh, dc) return c == nothing ? (dx, dh) : (dx, dh, dc)
end end
@ -215,10 +228,12 @@ end
function backwardWeights(rnn::RNNDesc{T}, x, h, y, reserve) where T function backwardWeights(rnn::RNNDesc{T}, x, h, y, reserve) where T
dw = zero(rnn.params) dw = zero(rnn.params)
cudnnRNNBackwardWeights(rnn, 1, @nogc begin
xDesc(x), x, hDesc(h)..., xDesc(y), y, cudnnRNNBackwardWeights(rnn, 1,
FilterDesc(T, (1, 1, length(dw))), dw, xDesc(x), x, hDesc(h)..., xDesc(y), y,
workspace[], reserve) FilterDesc(T, (1, 1, length(dw))), dw,
workspace[], reserve)
end
return params(dw, rnn.input, rnn.hidden, ngates(rnn)) return params(dw, rnn.input, rnn.hidden, ngates(rnn))
end end