remove importing CuMatrix
This commit is contained in:
parent
30648910c8
commit
63cb70dd23
@ -37,7 +37,7 @@ import Adapt: adapt, adapt_structure
|
|||||||
|
|
||||||
adapt_structure(T, xs::OneHotMatrix) = OneHotMatrix(xs.height, adapt(T, xs.data))
|
adapt_structure(T, xs::OneHotMatrix) = OneHotMatrix(xs.height, adapt(T, xs.data))
|
||||||
|
|
||||||
import .CuArrays: CuArray, CuMatrix, CuArrayStyle, cudaconvert
|
import .CuArrays: CuArray, CuArrayStyle, cudaconvert
|
||||||
import Base.Broadcast: BroadcastStyle, ArrayStyle
|
import Base.Broadcast: BroadcastStyle, ArrayStyle
|
||||||
BroadcastStyle(::Type{<:OneHotMatrix{<:CuArray}}) = CuArrayStyle{2}()
|
BroadcastStyle(::Type{<:OneHotMatrix{<:CuArray}}) = CuArrayStyle{2}()
|
||||||
cudaconvert(x::OneHotMatrix{<:CuArray}) = OneHotMatrix(x.height, cudaconvert(x.data))
|
cudaconvert(x::OneHotMatrix{<:CuArray}) = OneHotMatrix(x.height, cudaconvert(x.data))
|
||||||
|
Loading…
Reference in New Issue
Block a user