Adapt to CuArrays ArrayStyle changes.

This commit is contained in:
Tim Besard 2020-02-25 14:09:03 +01:00
parent 7e58766467
commit 4ed7d984db
1 changed files with 2 additions and 2 deletions

View File

@ -37,9 +37,9 @@ import Adapt: adapt, adapt_structure
adapt_structure(T, xs::OneHotMatrix) = OneHotMatrix(xs.height, adapt(T, xs.data)) adapt_structure(T, xs::OneHotMatrix) = OneHotMatrix(xs.height, adapt(T, xs.data))
import .CuArrays: CuArray, cudaconvert import .CuArrays: CuArray, CuArrayStyle, cudaconvert
import Base.Broadcast: BroadcastStyle, ArrayStyle import Base.Broadcast: BroadcastStyle, ArrayStyle
BroadcastStyle(::Type{<:OneHotMatrix{<:CuArray}}) = ArrayStyle{CuArray}() BroadcastStyle(::Type{<:OneHotMatrix{<:CuArray}}) = CuArrayStyle{2}()
cudaconvert(x::OneHotMatrix{<:CuArray}) = OneHotMatrix(x.height, cudaconvert(x.data)) cudaconvert(x::OneHotMatrix{<:CuArray}) = OneHotMatrix(x.height, cudaconvert(x.data))
""" """