diff --git a/src/onehot.jl b/src/onehot.jl index 48b7ccf5..2f1eb365 100644 --- a/src/onehot.jl +++ b/src/onehot.jl @@ -22,7 +22,10 @@ Base.:*(A::AbstractMatrix, B::OneHotMatrix) = A[:, map(x->x.ix, B.data)] Base.hcat(x::OneHotVector, xs::OneHotVector...) = OneHotMatrix([x, xs...]) @require CuArrays begin + import CuArrays: CuArray, cudaconvert CuArrays.cu(xs::OneHotMatrix) = OneHotMatrix(CuArrays.cu(xs.data)) + Base.Broadcast._containertype(::Type{<:OneHotMatrix{<:CuArray}}) = CuArray + cudaconvert(x::OneHotMatrix{<:CuArray}) = OneHotMatrix(cudaconvert(x.data)) end onehot(l, labels) = OneHotVector(findfirst(labels, l), length(labels))