speedup matmul of CuMatrix and OneHotMatrix
This commit is contained in:
parent
32e2435729
commit
427c55af92
|
@ -41,6 +41,10 @@ import .CuArrays: CuArray, CuArrayStyle, cudaconvert
|
|||
import Base.Broadcast: BroadcastStyle, ArrayStyle
|
||||
BroadcastStyle(::Type{<:OneHotMatrix{<:CuArray}}) = CuArrayStyle{2}()
|
||||
cudaconvert(x::OneHotMatrix{<:CuArray}) = OneHotMatrix(x.height, cudaconvert(x.data))
|
||||
function Base.:(*)(A::CuArrays.CuMatrix, B::OneHotMatrix{CuArrays.CuArray{OneHotVector,1}})
|
||||
I = CuArrays.CuArray{UInt32, 1}(B.data.buf, 2 .* B.data.dims, offset = B.data.offset)[1:2:end]
|
||||
A[:, Array(I)]
|
||||
end
|
||||
|
||||
"""
|
||||
onehot(l, labels[, unk])
|
||||
|
|
Loading…
Reference in New Issue