Compare commits
1 Commit
master
...
kf/onehotm
Author | SHA1 | Date |
---|---|---|
![]() |
45c7ab8e6d |
|
@ -1,32 +1,44 @@
|
|||
import Base: *
|
||||
|
||||
struct OneHotVector <: AbstractVector{Bool}
|
||||
ix::UInt32
|
||||
of::UInt32
|
||||
struct OneHotVector{T <: Integer} <: AbstractVector{Bool}
|
||||
ix::T
|
||||
of::T
|
||||
end
|
||||
|
||||
Base.size(xs::OneHotVector) = (Int64(xs.of),)
|
||||
Base.size(xs::OneHotVector) = (Int(xs.of),)
|
||||
|
||||
Base.getindex(xs::OneHotVector, i::Integer) = i == xs.ix
|
||||
|
||||
A::AbstractMatrix * b::OneHotVector = A[:, b.ix]
|
||||
|
||||
struct OneHotMatrix{A<:AbstractVector{OneHotVector}} <: AbstractMatrix{Bool}
|
||||
"""
|
||||
A matrix of one-hot column vectors
|
||||
"""
|
||||
struct OneHotMatrix{A<:AbstractVector{<:Integer}} <: AbstractMatrix{Bool}
|
||||
height::Int
|
||||
data::A
|
||||
end
|
||||
|
||||
Base.size(xs::OneHotMatrix) = (Int64(xs.height),length(xs.data))
|
||||
"""
    OneHotMatrix(xs::AbstractVector{<:OneHotVector})

Build a `OneHotMatrix` from a collection of one-hot vectors, keeping only the hot
index (`ix`) of each vector as the matrix data.

The matrix height is the length of the first vector, so `xs` must be non-empty.
An error is raised if any vector has a different length.

Any `AbstractVector` is accepted (not only `Vector`), so that callers such as
`batch`, which forward arbitrary arrays of one-hot vectors, dispatch here.
"""
function OneHotMatrix(xs::AbstractVector{<:OneHotVector})
    height = length(first(xs))
    indices = map(xs) do x
        # Every column must agree with the height taken from the first vector.
        length(x) == height || error("All one hot vectors must be the same length")
        x.ix
    end
    return OneHotMatrix(height, indices)
end
|
||||
|
||||
Base.getindex(xs::OneHotMatrix, i::Integer, j::Integer) = xs.data[j][i]
|
||||
Base.getindex(xs::OneHotMatrix, ::Colon, i::Integer) = xs.data[i]
|
||||
|
||||
Base.size(xs::OneHotMatrix) = (xs.height, length(xs.data))
|
||||
|
||||
Base.getindex(xs::OneHotMatrix, ::Colon, i::Integer) = OneHotVector(xs.data[i], xs.height)
|
||||
Base.getindex(xs::OneHotMatrix, i::Integer, j::Integer) = xs[:, j][i]
|
||||
Base.getindex(xs::OneHotMatrix, ::Colon, i::AbstractArray) = OneHotMatrix(xs.height, xs.data[i])
|
||||
|
||||
A::AbstractMatrix * B::OneHotMatrix = A[:, map(x->x.ix, B.data)]
|
||||
A::AbstractMatrix * B::OneHotMatrix = A[:, B.data]
|
||||
|
||||
Base.hcat(x::OneHotVector, xs::OneHotVector...) = OneHotMatrix(length(x), [x, xs...])
|
||||
Base.hcat(x::OneHotVector, xs::OneHotVector...) = OneHotMatrix([x, xs...])
|
||||
|
||||
batch(xs::AbstractArray{<:OneHotVector}) = OneHotMatrix(length(first(xs)), xs)
|
||||
batch(xs::AbstractArray{<:OneHotVector}) = OneHotMatrix(xs)
|
||||
|
||||
import Adapt: adapt, adapt_structure
|
||||
|
||||
|
@ -39,20 +51,22 @@ adapt_structure(T, xs::OneHotMatrix) = OneHotMatrix(xs.height, adapt(T, xs.data)
|
|||
cudaconvert(x::OneHotMatrix{<:CuArray}) = OneHotMatrix(x.height, cudaconvert(x.data))
|
||||
end
|
||||
|
||||
function onehot(l, labels)
|
||||
i = something(findfirst(isequal(l), labels), 0)
|
||||
i > 0 || error("Value $l is not in labels")
|
||||
OneHotVector(i, length(labels))
|
||||
"""
    onehotidx(l, labels)

Return the index of the first element of `labels` that `isequal` to `l`.
Raise an error if `l` does not occur in `labels`.
"""
function onehotidx(l, labels)
    idx = findfirst(isequal(l), labels)
    if idx === nothing
        error("Value $(repr(l; context=:limited=>true)) is not in labels")
    end
    return idx
end
|
||||
|
||||
function onehot(l, labels, unk)
|
||||
i = something(findfirst(isequal(l), labels), 0)
|
||||
i > 0 || return onehot(unk, labels)
|
||||
OneHotVector(i, length(labels))
|
||||
"""
    onehotidx(l, labels, unk)

Return the index of the first element of `labels` that `isequal` to `l`. When `l`
is not present, fall back to the index of `unk`, looked up with the two-argument
`onehotidx(unk, labels)`.
"""
function onehotidx(l, labels, unk)
    idx = findfirst(isequal(l), labels)
    return idx === nothing ? onehotidx(unk, labels) : idx
end
|
||||
|
||||
"""
    onehot(l, labels[, unk])

Return a `OneHotVector` of length `length(labels)` whose hot index is the position
of `l` in `labels`, as computed by `onehotidx`. Any `unk` argument is forwarded
unchanged to `onehotidx`.
"""
onehot(l, labels, unk...) = OneHotVector(onehotidx(l, labels, unk...), length(labels))
|
||||
|
||||
onehotbatch(ls, labels, unk...) =
|
||||
OneHotMatrix(length(labels), [onehot(l, labels, unk...) for l in ls])
|
||||
OneHotMatrix(length(labels), [onehotidx(l, labels, unk...) for l in ls])
|
||||
|
||||
"""
    onecold(y::AbstractVector, labels = 1:length(y))

Return the entry of `labels` located at the position of the largest element of `y`.
With the default `labels`, this is the index of the largest element itself.
"""
function onecold(y::AbstractVector, labels = 1:length(y))
    hot = Base.argmax(y)
    return labels[hot]
end
|
||||
|
||||
|
|
Loading…
Reference in New Issue