Use CuArrays.ones
instead cuones
which is deprecated
This commit is contained in:
parent
7c111e7cde
commit
14affbc91b
@ -130,8 +130,8 @@ end
|
|||||||
# TODO: can we just manipulate strides here?
|
# TODO: can we just manipulate strides here?
|
||||||
# TODO: should use repmat, but this isn't implemented.
|
# TODO: should use repmat, but this isn't implemented.
|
||||||
hBatch(x::AbstractVector, h::CuVector) = h
|
hBatch(x::AbstractVector, h::CuVector) = h
|
||||||
hBatch(x::AbstractMatrix, h::CuVector) = h .* cuones(1, size(x, 2))
|
hBatch(x::AbstractMatrix, h::CuVector) = h .* CuArrays.ones(1, size(x, 2))
|
||||||
hBatch(x::AbstractMatrix, h::CuMatrix) = h .* cuones(1, size(h,2) == 1 ? size(x,2) : 1)
|
hBatch(x::AbstractMatrix, h::CuMatrix) = h .* CuArrays.ones(1, size(h,2) == 1 ? size(x,2) : 1)
|
||||||
|
|
||||||
function forward(rnn::RNNDesc{T}, x::CuArray{T}, h_::CuArray{T}, c_ = nothing, train = Val{false}) where T
|
function forward(rnn::RNNDesc{T}, x::CuArray{T}, h_::CuArray{T}, c_ = nothing, train = Val{false}) where T
|
||||||
h = hBatch(x, h_)
|
h = hBatch(x, h_)
|
||||||
|
Loading…
Reference in New Issue
Block a user