speedup matmul of CuMatrix and OneHotMatrix

This commit is contained in:
Yao Lu 2020-04-20 20:01:17 +08:00
parent 9237cdaf5b
commit eb6898ea19
1 changed files with 5 additions and 1 deletions

View File

@ -37,10 +37,14 @@ import Adapt: adapt, adapt_structure
adapt_structure(T, xs::OneHotMatrix) = OneHotMatrix(xs.height, adapt(T, xs.data))
import .CuArrays: CuArray, CuArrayStyle, cudaconvert
import .CuArrays: CuArray, CuMatrix, CuArrayStyle, cudaconvert
import Base.Broadcast: BroadcastStyle, ArrayStyle
BroadcastStyle(::Type{<:OneHotMatrix{<:CuArray}}) = CuArrayStyle{2}()
cudaconvert(x::OneHotMatrix{<:CuArray}) = OneHotMatrix(x.height, cudaconvert(x.data))
function Base.:(*)(A::CuMatrix, B::OneHotMatrix{CuArray{OneHotVector,1}})
I = CuArray{UInt32, 1}(B.data.buf, 2 .* B.data.dims, offset = B.data.offset)[1:2:end]
A[:, Array(I)]
end
"""
onehot(l, labels[, unk])