onehotmatrix cuda support
This commit is contained in:
parent
a60a754d68
commit
a32ae4914c
@ -22,7 +22,10 @@ Base.:*(A::AbstractMatrix, B::OneHotMatrix) = A[:, map(x->x.ix, B.data)]
|
||||
Base.hcat(x::OneHotVector, xs::OneHotVector...) = OneHotMatrix([x, xs...])
|
||||
|
||||
@require CuArrays begin
|
||||
import CuArrays: CuArray, cudaconvert
|
||||
CuArrays.cu(xs::OneHotMatrix) = OneHotMatrix(CuArrays.cu(xs.data))
|
||||
Base.Broadcast._containertype(::Type{<:OneHotMatrix{<:CuArray}}) = CuArray
|
||||
cudaconvert(x::OneHotMatrix{<:CuArray}) = OneHotMatrix(cudaconvert(x.data))
|
||||
end
|
||||
|
||||
onehot(l, labels) = OneHotVector(findfirst(labels, l), length(labels))
|
||||
|
Loading…
Reference in New Issue
Block a user