Merge #837
837: Use `CuArrays.ones` instead `cuones` which is deprecated r=dhairyagandhi96 a=mimadrid I Co-authored-by: Miguel Madrid Mencía <miguel.madrid.mencia@gmail.com>
This commit is contained in:
commit
aab3c4e052
@ -130,8 +130,8 @@ end
|
|||||||
# TODO: can we just manipulate strides here?
|
# TODO: can we just manipulate strides here?
|
||||||
# TODO: should use repmat, but this isn't implemented.
|
# TODO: should use repmat, but this isn't implemented.
|
||||||
hBatch(x::AbstractVector, h::CuVector) = h
|
hBatch(x::AbstractVector, h::CuVector) = h
|
||||||
hBatch(x::AbstractMatrix, h::CuVector) = h .* cuones(1, size(x, 2))
|
hBatch(x::AbstractMatrix, h::CuVector) = h .* CuArrays.ones(1, size(x, 2))
|
||||||
hBatch(x::AbstractMatrix, h::CuMatrix) = h .* cuones(1, size(h,2) == 1 ? size(x,2) : 1)
|
hBatch(x::AbstractMatrix, h::CuMatrix) = h .* CuArrays.ones(1, size(h,2) == 1 ? size(x,2) : 1)
|
||||||
|
|
||||||
function forward(rnn::RNNDesc{T}, x::CuArray{T}, h_::CuArray{T}, c_ = nothing, train = Val{false}) where T
|
function forward(rnn::RNNDesc{T}, x::CuArray{T}, h_::CuArray{T}, c_ = nothing, train = Val{false}) where T
|
||||||
h = hBatch(x, h_)
|
h = hBatch(x, h_)
|
||||||
|
Loading…
Reference in New Issue
Block a user