transfer onehot indices back to cpu
This commit is contained in:
parent
eb6898ea19
commit
30648910c8
@ -27,7 +27,7 @@ Base.getindex(xs::OneHotMatrix, ::Colon, ::Colon) = OneHotMatrix(xs.height, copy
|
|||||||
|
|
||||||
Base.getindex(xs::OneHotMatrix, i::Integer, ::Colon) = map(x -> x[i], xs.data)
|
Base.getindex(xs::OneHotMatrix, i::Integer, ::Colon) = map(x -> x[i], xs.data)
|
||||||
|
|
||||||
A::AbstractMatrix * B::OneHotMatrix = A[:, map(x->x.ix, B.data)]
|
A::AbstractMatrix * B::OneHotMatrix = A[:, cpu(map(x->x.ix, B.data))]
|
||||||
|
|
||||||
Base.hcat(x::OneHotVector, xs::OneHotVector...) = OneHotMatrix(length(x), [x, xs...])
|
Base.hcat(x::OneHotVector, xs::OneHotVector...) = OneHotMatrix(length(x), [x, xs...])
|
||||||
|
|
||||||
@ -41,10 +41,6 @@ import .CuArrays: CuArray, CuMatrix, CuArrayStyle, cudaconvert
|
|||||||
import Base.Broadcast: BroadcastStyle, ArrayStyle
|
import Base.Broadcast: BroadcastStyle, ArrayStyle
|
||||||
BroadcastStyle(::Type{<:OneHotMatrix{<:CuArray}}) = CuArrayStyle{2}()
|
BroadcastStyle(::Type{<:OneHotMatrix{<:CuArray}}) = CuArrayStyle{2}()
|
||||||
cudaconvert(x::OneHotMatrix{<:CuArray}) = OneHotMatrix(x.height, cudaconvert(x.data))
|
cudaconvert(x::OneHotMatrix{<:CuArray}) = OneHotMatrix(x.height, cudaconvert(x.data))
|
||||||
function Base.:(*)(A::CuMatrix, B::OneHotMatrix{CuArray{OneHotVector,1}})
|
|
||||||
I = CuArray{UInt32, 1}(B.data.buf, 2 .* B.data.dims, offset = B.data.offset)[1:2:end]
|
|
||||||
A[:, Array(I)]
|
|
||||||
end
|
|
||||||
|
|
||||||
"""
|
"""
|
||||||
onehot(l, labels[, unk])
|
onehot(l, labels[, unk])
|
||||||
|
Loading…
Reference in New Issue
Block a user