nullable c refactor
This commit is contained in:
parent
07e1b1e0a9
commit
f866fbe575
@ -130,6 +130,11 @@ function cudnnRNNForward(::Type{T}, rnn, seqlen, xd, x, hd, h, cd, c, wd, w, yd,
|
|||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
|
hDesc(h::Void) = C_NULL, C_NULL
|
||||||
|
function hDesc(h::CuArray)
|
||||||
|
TensorDesc(eltype(h), (size(h, 1), size(h, 2), 1)), h
|
||||||
|
end
|
||||||
|
|
||||||
function forward(rnn::RNNDesc{T}, x::CuArray{T}, h::CuArray{T}, c = nothing; train = false) where T
|
function forward(rnn::RNNDesc{T}, x::CuArray{T}, h::CuArray{T}, c = nothing; train = false) where T
|
||||||
@assert size(x, 1) == rnn.input
|
@assert size(x, 1) == rnn.input
|
||||||
@assert size(h, 1) == rnn.hidden
|
@assert size(h, 1) == rnn.hidden
|
||||||
@ -140,29 +145,20 @@ function forward(rnn::RNNDesc{T}, x::CuArray{T}, h::CuArray{T}, c = nothing; tra
|
|||||||
ydesc = [TensorDesc(T, (1, size(y, 1), size(y, 2)))]
|
ydesc = [TensorDesc(T, (1, size(y, 1), size(y, 2)))]
|
||||||
workspace = CuVector{UInt8}(rnnWorkspaceSize(rnn, seqLength, xdesc)) # TODO: reuse this
|
workspace = CuVector{UInt8}(rnnWorkspaceSize(rnn, seqLength, xdesc)) # TODO: reuse this
|
||||||
reserve = train ? getreserve(rnn, seqLength, xdesc) : rnn.reserve
|
reserve = train ? getreserve(rnn, seqLength, xdesc) : rnn.reserve
|
||||||
if c ≠ nothing
|
cy = c == nothing ? c : similar(c)
|
||||||
@assert size(c, 1) == rnn.hidden
|
|
||||||
@assert size(c, 2) == size(h, 2)
|
|
||||||
cptr = c
|
|
||||||
cdesc = TensorDesc(T, (size(c, 1), size(c, 2), 1))
|
|
||||||
cout = similar(c)
|
|
||||||
coutdesc = TensorDesc(T, (size(cout, 1), size(cout, 2), 1))
|
|
||||||
else
|
|
||||||
cptr = cdesc = cout = coutdesc = C_NULL
|
|
||||||
end
|
|
||||||
cudnnRNNForward(T, rnn, seqLength,
|
cudnnRNNForward(T, rnn, seqLength,
|
||||||
xdesc, x,
|
xdesc, x,
|
||||||
TensorDesc(T, (size(h, 1), size(h, 2), 1)), h,
|
hDesc(h)...,
|
||||||
cdesc, cptr,
|
hDesc(c)...,
|
||||||
TensorDesc(T, (1, 1, length(rnn.params))), rnn.params,
|
TensorDesc(T, (1, 1, length(rnn.params))), rnn.params,
|
||||||
ydesc, y,
|
ydesc, y,
|
||||||
C_NULL, C_NULL, # hout
|
C_NULL, C_NULL, # hout
|
||||||
coutdesc, cout,
|
hDesc(cy)...,
|
||||||
workspace, reserve, train = train)
|
workspace, reserve, train = train)
|
||||||
if c == nothing
|
if c == nothing
|
||||||
return y, y
|
return y, y
|
||||||
else
|
else
|
||||||
return y, y, cout
|
return y, y, cy
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user