Some memory improvements to OneHotMatrix

1. Parameterize OneHotVector on its Integer index type, so that vectors of
   OneHotVectors use no more memory than required.
2. Switch OneHotMatrix from storing a vector of OneHotVectors to only storing
   the data and the size of the vector (reconstructing the vector locally), thus
   saving half the memory required and eliminating a transpose operation for
   matmul with OneHotMatrix on TPU.
This commit is contained in:
Keno Fischer 2019-01-28 20:23:50 -05:00
parent 8386a49bf9
commit 45c7ab8e6d

View File

@ -1,32 +1,44 @@
import Base: * import Base: *
struct OneHotVector <: AbstractVector{Bool} struct OneHotVector{T <: Integer} <: AbstractVector{Bool}
ix::UInt32 ix::T
of::UInt32 of::T
end end
Base.size(xs::OneHotVector) = (Int64(xs.of),) Base.size(xs::OneHotVector) = (Int(xs.of),)
Base.getindex(xs::OneHotVector, i::Integer) = i == xs.ix Base.getindex(xs::OneHotVector, i::Integer) = i == xs.ix
A::AbstractMatrix * b::OneHotVector = A[:, b.ix] A::AbstractMatrix * b::OneHotVector = A[:, b.ix]
struct OneHotMatrix{A<:AbstractVector{OneHotVector}} <: AbstractMatrix{Bool} """
A matrix of one-hot column vectors
"""
struct OneHotMatrix{A<:AbstractVector{<:Integer}} <: AbstractMatrix{Bool}
height::Int height::Int
data::A data::A
end end
Base.size(xs::OneHotMatrix) = (Int64(xs.height),length(xs.data)) function OneHotMatrix(xs::Vector{<:OneHotVector})
height = length(xs[1])
OneHotMatrix(height, map(xs) do x
length(x) == height || error("All one hot vectors must be the same length")
x.ix
end)
end
Base.getindex(xs::OneHotMatrix, i::Integer, j::Integer) = xs.data[j][i]
Base.getindex(xs::OneHotMatrix, ::Colon, i::Integer) = xs.data[i] Base.size(xs::OneHotMatrix) = (xs.height, length(xs.data))
# Column `i` of the matrix: the stored entry `xs.data[i]` is wrapped, together
# with `xs.height`, in a freshly built OneHotVector.
# NOTE(review): OneHotVector{T} declares both fields as the same `T`; if the
# eltype of `xs.data` is not `Int` (the type of `height`), this call may have no
# matching constructor — confirm whether an outer constructor handles that.
function Base.getindex(xs::OneHotMatrix, ::Colon, i::Integer)
    hot = xs.data[i]
    return OneHotVector(hot, xs.height)
end
# Scalar lookup: fetch column `j` through the `xs[:, j]` method, then index
# that column at row `i`.
function Base.getindex(xs::OneHotMatrix, i::Integer, j::Integer)
    column = xs[:, j]
    return column[i]
end
Base.getindex(xs::OneHotMatrix, ::Colon, i::AbstractArray) = OneHotMatrix(xs.height, xs.data[i]) Base.getindex(xs::OneHotMatrix, ::Colon, i::AbstractArray) = OneHotMatrix(xs.height, xs.data[i])
A::AbstractMatrix * B::OneHotMatrix = A[:, map(x->x.ix, B.data)] A::AbstractMatrix * B::OneHotMatrix = A[:, B.data]
Base.hcat(x::OneHotVector, xs::OneHotVector...) = OneHotMatrix(length(x), [x, xs...]) Base.hcat(x::OneHotVector, xs::OneHotVector...) = OneHotMatrix([x, xs...])
batch(xs::AbstractArray{<:OneHotVector}) = OneHotMatrix(length(first(xs)), xs) batch(xs::AbstractArray{<:OneHotVector}) = OneHotMatrix(xs)
import Adapt: adapt, adapt_structure import Adapt: adapt, adapt_structure
@ -39,20 +51,22 @@ adapt_structure(T, xs::OneHotMatrix) = OneHotMatrix(xs.height, adapt(T, xs.data)
cudaconvert(x::OneHotMatrix{<:CuArray}) = OneHotMatrix(x.height, cudaconvert(x.data)) cudaconvert(x::OneHotMatrix{<:CuArray}) = OneHotMatrix(x.height, cudaconvert(x.data))
end end
function onehot(l, labels) function onehotidx(l, labels)
i = something(findfirst(isequal(l), labels), 0) i = findfirst(isequal(l), labels)
i > 0 || error("Value $l is not in labels") i !== nothing || error("Value $(repr(l; context=:limited=>true)) is not in labels")
OneHotVector(i, length(labels)) i
end end
function onehot(l, labels, unk) function onehotidx(l, labels, unk)
i = something(findfirst(isequal(l), labels), 0) i = findfirst(isequal(l), labels)
i > 0 || return onehot(unk, labels) i !== nothing || return onehotidx(unk, labels)
OneHotVector(i, length(labels)) i
end end
"""
    onehot(l, labels[, unk])

Build a `OneHotVector` of length `length(labels)` whose hot index is
`onehotidx(l, labels, unk...)`. Any `unk` argument is forwarded unchanged to
`onehotidx`.
"""
onehot(l, labels, unk...) = OneHotVector(onehotidx(l, labels, unk...), length(labels))
onehotbatch(ls, labels, unk...) = onehotbatch(ls, labels, unk...) =
OneHotMatrix(length(labels), [onehot(l, labels, unk...) for l in ls]) OneHotMatrix(length(labels), [onehotidx(l, labels, unk...) for l in ls])
"""
    onecold(y::AbstractVector, labels = 1:length(y))

Return the entry of `labels` located at the index of the largest element of
`y`. When `labels` is omitted, the range `1:length(y)` is used as the labels.
"""
function onecold(y::AbstractVector, labels = 1:length(y))
    hottest = Base.argmax(y)
    return labels[hottest]
end