diff --git a/src/onehot.jl b/src/onehot.jl index 7a3123ec..b480d9c0 100644 --- a/src/onehot.jl +++ b/src/onehot.jl @@ -37,9 +37,9 @@ import Adapt: adapt, adapt_structure adapt_structure(T, xs::OneHotMatrix) = OneHotMatrix(xs.height, adapt(T, xs.data)) -import .CuArrays: CuArray, cudaconvert +import .CuArrays: CuArray, CuArrayStyle, cudaconvert import Base.Broadcast: BroadcastStyle, ArrayStyle -BroadcastStyle(::Type{<:OneHotMatrix{<:CuArray}}) = ArrayStyle{CuArray}() +BroadcastStyle(::Type{<:OneHotMatrix{<:CuArray}}) = CuArrayStyle{2}() cudaconvert(x::OneHotMatrix{<:CuArray}) = OneHotMatrix(x.height, cudaconvert(x.data)) """