Adapt to CuArrays ArrayStyle changes.
This commit is contained in:
parent
7e58766467
commit
4ed7d984db
|
@ -37,9 +37,9 @@ import Adapt: adapt, adapt_structure
|
|||
|
||||
adapt_structure(T, xs::OneHotMatrix) = OneHotMatrix(xs.height, adapt(T, xs.data))
|
||||
|
||||
import .CuArrays: CuArray, cudaconvert
|
||||
import .CuArrays: CuArray, CuArrayStyle, cudaconvert
|
||||
import Base.Broadcast: BroadcastStyle, ArrayStyle
|
||||
BroadcastStyle(::Type{<:OneHotMatrix{<:CuArray}}) = ArrayStyle{CuArray}()
|
||||
BroadcastStyle(::Type{<:OneHotMatrix{<:CuArray}}) = CuArrayStyle{2}()
|
||||
cudaconvert(x::OneHotMatrix{<:CuArray}) = OneHotMatrix(x.height, cudaconvert(x.data))
|
||||
|
||||
"""
|
||||
|
|
Loading…
Reference in New Issue