Some memory improvements to OneHotMatrix

1. Parameterize OneHotVector on its Integer type, so that vectors of
   OneHotVectors use no more memory than required.
2. Switch OneHotMatrix from storing a vector of OneHotVectors to only storing
   the data and the size of the vector (reconstructing the vector locally), thus
   saving half the memory required and eliminating a transpose operation for
   matmul with OneHotMatrix on TPU.
This commit is contained in:
Keno Fischer 2019-01-28 20:23:50 -05:00
parent 8386a49bf9
commit 45c7ab8e6d
1 changed files with 34 additions and 20 deletions

View File

@ -1,32 +1,44 @@
import Base: *
struct OneHotVector <: AbstractVector{Bool}
ix::UInt32
of::UInt32
"""
    OneHotVector(ix::T, of::T) where {T<:Integer}

A Boolean vector of length `of` in which entry `i` is `true` exactly when
`i == ix`. Both fields are stored with the same integer type `T`.
"""
struct OneHotVector{T <: Integer} <: AbstractVector{Bool}
ix::T
of::T
end
Base.size(xs::OneHotVector) = (Int64(xs.of),)
# Length of the one-hot vector, reported as a 1-tuple of `Int`.
function Base.size(v::OneHotVector)
    len = Int(v.of)
    return (len,)
end
# Element `i` is `true` exactly when `i` equals the stored hot index.
function Base.getindex(v::OneHotVector, i::Integer)
    return i == v.ix
end
# Multiplying by a one-hot vector selects the single column of `A` at the hot index.
function Base.:*(A::AbstractMatrix, b::OneHotVector)
    col = b.ix
    return A[:, col]
end
struct OneHotMatrix{A<:AbstractVector{OneHotVector}} <: AbstractMatrix{Bool}
"""
    OneHotMatrix(height, data)

A matrix of one-hot column vectors.

Only the hot index of each column is stored: `data[j]` is the hot index of
column `j`, and `height` is the number of rows.
"""
struct OneHotMatrix{A<:AbstractVector{<:Integer}} <: AbstractMatrix{Bool}
height::Int
data::A
end
Base.size(xs::OneHotMatrix) = (Int64(xs.height),length(xs.data))
"""
    OneHotMatrix(xs::AbstractVector{<:OneHotVector})

Build a `OneHotMatrix` whose columns are the one-hot vectors in `xs`, keeping
only each vector's hot index.

Any `AbstractVector` of one-hot vectors is accepted, not just `Vector`.

Throws an `ArgumentError` if `xs` is empty (the height cannot be inferred) and
an error if the vectors do not all have the same length.
"""
function OneHotMatrix(xs::AbstractVector{<:OneHotVector})
    if isempty(xs)
        throw(ArgumentError("Cannot build a OneHotMatrix from zero one hot vectors"))
    end
    # `first` rather than `xs[1]` so that non-1-based containers also work.
    height = length(first(xs))
    indices = map(xs) do x
        length(x) == height || error("All one hot vectors must be the same length")
        x.ix
    end
    return OneHotMatrix(height, indices)
end
Base.getindex(xs::OneHotMatrix, i::Integer, j::Integer) = xs.data[j][i]
Base.getindex(xs::OneHotMatrix, ::Colon, i::Integer) = xs.data[i]
# Dimensions are (rows, columns): the stored height by the number of hot indices.
function Base.size(M::OneHotMatrix)
    ncols = length(M.data)
    return (M.height, ncols)
end
# Column `i` is rebuilt on demand as a `OneHotVector` from its stored hot index.
# The height (always an `Int`) is converted to the index's own integer type:
# the default `OneHotVector(ix::T, of::T)` constructor requires both arguments
# to share a type, so passing them unconverted throws a `MethodError` whenever
# the indices are stored as anything other than `Int` (e.g. `UInt32`).
function Base.getindex(xs::OneHotMatrix, ::Colon, i::Integer)
    ix = xs.data[i]
    return OneHotVector{typeof(ix)}(ix, xs.height)
end
# Scalar lookup: fetch column `j` first, then read row `i` from it.
function Base.getindex(M::OneHotMatrix, i::Integer, j::Integer)
    column = M[:, j]
    return column[i]
end
# Selecting several columns returns another OneHotMatrix: the chosen hot
# indices are kept and the height is unchanged.
function Base.getindex(M::OneHotMatrix, ::Colon, i::AbstractArray)
    picked = M.data[i]
    return OneHotMatrix(M.height, picked)
end
A::AbstractMatrix * B::OneHotMatrix = A[:, map(x->x.ix, B.data)]
# Multiplying by a one-hot matrix gathers the columns of `A` named by the
# stored hot indices.
function Base.:*(A::AbstractMatrix, B::OneHotMatrix)
    cols = B.data
    return A[:, cols]
end
Base.hcat(x::OneHotVector, xs::OneHotVector...) = OneHotMatrix(length(x), [x, xs...])
# Horizontal concatenation of one-hot vectors produces a OneHotMatrix built
# from the vectors in argument order.
function Base.hcat(x::OneHotVector, xs::OneHotVector...)
    columns = [x, xs...]
    return OneHotMatrix(columns)
end
batch(xs::AbstractArray{<:OneHotVector}) = OneHotMatrix(length(first(xs)), xs)
"""
    batch(xs::AbstractArray{<:OneHotVector})

Combine the one-hot vectors in `xs` into a single `OneHotMatrix`.
"""
function batch(xs::AbstractArray{<:OneHotVector})
    return OneHotMatrix(xs)
end
import Adapt: adapt, adapt_structure
@ -39,20 +51,22 @@ adapt_structure(T, xs::OneHotMatrix) = OneHotMatrix(xs.height, adapt(T, xs.data)
# Apply `cudaconvert` to the stored indices, keeping the height unchanged.
function cudaconvert(x::OneHotMatrix{<:CuArray})
    converted = cudaconvert(x.data)
    return OneHotMatrix(x.height, converted)
end
end
function onehot(l, labels)
i = something(findfirst(isequal(l), labels), 0)
i > 0 || error("Value $l is not in labels")
OneHotVector(i, length(labels))
"""
    onehotidx(l, labels)

Return the index of the first element of `labels` that is `isequal` to `l`.
Raises an error if `l` does not occur in `labels`.
"""
function onehotidx(l, labels)
    idx = findfirst(isequal(l), labels)
    if idx === nothing
        error("Value $(repr(l; context=:limited=>true)) is not in labels")
    end
    return idx
end
function onehot(l, labels, unk)
i = something(findfirst(isequal(l), labels), 0)
i > 0 || return onehot(unk, labels)
OneHotVector(i, length(labels))
"""
    onehotidx(l, labels, unk)

Return the index of the first element of `labels` that is `isequal` to `l`.
If `l` does not occur in `labels`, fall back to the index of `unk` via the
two-argument method.
"""
function onehotidx(l, labels, unk)
    idx = findfirst(isequal(l), labels)
    # The fallback lookup is only performed when `l` itself was not found.
    return idx === nothing ? onehotidx(unk, labels) : idx
end
"""
    onehot(l, labels, unk...)

Return a `OneHotVector` of length `length(labels)` whose hot index is the
position of `l` in `labels`, as resolved by `onehotidx`. Any `unk` arguments
are passed through to `onehotidx`.
"""
onehot(l, labels, unk...) = OneHotVector(onehotidx(l, labels, unk...), length(labels))
"""
    onehotbatch(ls, labels, unk...)

Encode every element of `ls` against `labels` and collect the results into a
`OneHotMatrix` with `length(labels)` rows. Any `unk` arguments are forwarded
unchanged to the per-element encoder.
"""
onehotbatch(ls, labels, unk...) =
OneHotMatrix(length(labels), [onehot(l, labels, unk...) for l in ls])
OneHotMatrix(length(labels), [onehotidx(l, labels, unk...) for l in ls])
"""
    onecold(y::AbstractVector, labels = 1:length(y))

Return the entry of `labels` at the position of the largest element of `y`.
By default `labels` is `1:length(y)`.
"""
function onecold(y::AbstractVector, labels = 1:length(y))
    hottest = Base.argmax(y)
    return labels[hottest]
end