diff --git a/src/cuda/curnn.jl b/src/cuda/curnn.jl index 4990599f..c60104d2 100644 --- a/src/cuda/curnn.jl +++ b/src/cuda/curnn.jl @@ -130,8 +130,8 @@ end # TODO: can we just manipulate strides here? # TODO: should use repmat, but this isn't implemented. hBatch(x::AbstractVector, h::CuVector) = h -hBatch(x::AbstractMatrix, h::CuVector) = h .* cuones(1, size(x, 2)) -hBatch(x::AbstractMatrix, h::CuMatrix) = h .* cuones(1, size(h,2) == 1 ? size(x,2) : 1) +hBatch(x::AbstractMatrix, h::CuVector) = h .* CuArrays.ones(1, size(x, 2)) +hBatch(x::AbstractMatrix, h::CuMatrix) = h .* CuArrays.ones(1, size(h,2) == 1 ? size(x,2) : 1) function forward(rnn::RNNDesc{T}, x::CuArray{T}, h_::CuArray{T}, c_ = nothing, train = Val{false}) where T h = hBatch(x, h_)