From 427c55af9287783807ad4198ecc445da56872225 Mon Sep 17 00:00:00 2001 From: Yao Lu Date: Mon, 20 Apr 2020 19:11:57 +0800 Subject: [PATCH] speedup matmul of CuMatrix and OneHotMatrix --- src/onehot.jl | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/onehot.jl b/src/onehot.jl index 4b7e5e36..9d5394ef 100644 --- a/src/onehot.jl +++ b/src/onehot.jl @@ -41,6 +41,10 @@ import .CuArrays: CuArray, CuArrayStyle, cudaconvert import Base.Broadcast: BroadcastStyle, ArrayStyle BroadcastStyle(::Type{<:OneHotMatrix{<:CuArray}}) = CuArrayStyle{2}() cudaconvert(x::OneHotMatrix{<:CuArray}) = OneHotMatrix(x.height, cudaconvert(x.data)) +function Base.:(*)(A::CuArrays.CuMatrix, B::OneHotMatrix{CuArrays.CuArray{OneHotVector,1}}) + I = CuArrays.CuArray{UInt32, 1}(B.data.buf, 2 .* B.data.dims, offset = B.data.offset)[1:2:end] + A[:, Array(I)] +end """ onehot(l, labels[, unk])