remove importing CuMatrix

This commit is contained in:
Yao Lu 2020-05-09 19:13:52 +08:00
parent 30648910c8
commit 63cb70dd23
1 changed files with 1 additions and 1 deletions

View File

@ -37,7 +37,7 @@ import Adapt: adapt, adapt_structure
adapt_structure(T, xs::OneHotMatrix) = OneHotMatrix(xs.height, adapt(T, xs.data))
import .CuArrays: CuArray, CuMatrix, CuArrayStyle, cudaconvert
import .CuArrays: CuArray, CuArrayStyle, cudaconvert
import Base.Broadcast: BroadcastStyle, ArrayStyle
BroadcastStyle(::Type{<:OneHotMatrix{<:CuArray}}) = CuArrayStyle{2}()
cudaconvert(x::OneHotMatrix{<:CuArray}) = OneHotMatrix(x.height, cudaconvert(x.data))