nullable c refactor

This commit is contained in:
Mike J Innes 2018-02-06 13:29:57 +00:00
parent 07e1b1e0a9
commit f866fbe575

View File

@ -130,6 +130,11 @@ function cudnnRNNForward(::Type{T}, rnn, seqlen, xd, x, hd, h, cd, c, wd, w, yd,
end end
end end
hDesc(h::Void) = C_NULL, C_NULL
function hDesc(h::CuArray)
TensorDesc(eltype(h), (size(h, 1), size(h, 2), 1)), h
end
function forward(rnn::RNNDesc{T}, x::CuArray{T}, h::CuArray{T}, c = nothing; train = false) where T function forward(rnn::RNNDesc{T}, x::CuArray{T}, h::CuArray{T}, c = nothing; train = false) where T
@assert size(x, 1) == rnn.input @assert size(x, 1) == rnn.input
@assert size(h, 1) == rnn.hidden @assert size(h, 1) == rnn.hidden
@ -140,29 +145,20 @@ function forward(rnn::RNNDesc{T}, x::CuArray{T}, h::CuArray{T}, c = nothing; tra
ydesc = [TensorDesc(T, (1, size(y, 1), size(y, 2)))] ydesc = [TensorDesc(T, (1, size(y, 1), size(y, 2)))]
workspace = CuVector{UInt8}(rnnWorkspaceSize(rnn, seqLength, xdesc)) # TODO: reuse this workspace = CuVector{UInt8}(rnnWorkspaceSize(rnn, seqLength, xdesc)) # TODO: reuse this
reserve = train ? getreserve(rnn, seqLength, xdesc) : rnn.reserve reserve = train ? getreserve(rnn, seqLength, xdesc) : rnn.reserve
if c nothing cy = c == nothing ? c : similar(c)
@assert size(c, 1) == rnn.hidden
@assert size(c, 2) == size(h, 2)
cptr = c
cdesc = TensorDesc(T, (size(c, 1), size(c, 2), 1))
cout = similar(c)
coutdesc = TensorDesc(T, (size(cout, 1), size(cout, 2), 1))
else
cptr = cdesc = cout = coutdesc = C_NULL
end
cudnnRNNForward(T, rnn, seqLength, cudnnRNNForward(T, rnn, seqLength,
xdesc, x, xdesc, x,
TensorDesc(T, (size(h, 1), size(h, 2), 1)), h, hDesc(h)...,
cdesc, cptr, hDesc(c)...,
TensorDesc(T, (1, 1, length(rnn.params))), rnn.params, TensorDesc(T, (1, 1, length(rnn.params))), rnn.params,
ydesc, y, ydesc, y,
C_NULL, C_NULL, # hout C_NULL, C_NULL, # hout
coutdesc, cout, hDesc(cy)...,
workspace, reserve, train = train) workspace, reserve, train = train)
if c == nothing if c == nothing
return y, y return y, y
else else
return y, y, cout return y, y, cy
end end
end end