diff --git a/src/onehot.jl b/src/onehot.jl index 4b7e5e36..9d5394ef 100644 --- a/src/onehot.jl +++ b/src/onehot.jl @@ -41,6 +41,10 @@ import .CuArrays: CuArray, CuArrayStyle, cudaconvert import Base.Broadcast: BroadcastStyle, ArrayStyle BroadcastStyle(::Type{<:OneHotMatrix{<:CuArray}}) = CuArrayStyle{2}() cudaconvert(x::OneHotMatrix{<:CuArray}) = OneHotMatrix(x.height, cudaconvert(x.data)) +function Base.:(*)(A::CuArrays.CuMatrix, B::OneHotMatrix{CuArrays.CuArray{OneHotVector,1}}) + I = CuArrays.CuArray{UInt32, 1}(B.data.buf, 2 .* B.data.dims, offset = B.data.offset)[1:2:end] + A[:, Array(I)] +end """ onehot(l, labels[, unk])