Compare commits

...

1 Commits

Author SHA1 Message Date
Keno Fischer 45c7ab8e6d Some memory improvements to OneHotMatrix
1. Parameterize OneHotVector on Integer type, to avoid using more memory
   than required for vectors of them.
2. Switch OneHotMatrix from storing a vector of OneHotVectors to only storing
   the data and the size of the vector (reconstructing the vector locally), thus
   saving half the memory required and eliminating a transpose operation for
   matmul with OneHotMatrix on TPU.
2019-01-28 20:26:59 -05:00
1 changed files with 34 additions and 20 deletions

View File

@ -1,32 +1,44 @@
import Base: *
struct OneHotVector <: AbstractVector{Bool}
ix::UInt32
of::UInt32
"""
    OneHotVector(ix, of)

Boolean vector of length `of` whose only `true` entry is at index `ix`.
Both fields share the integer type `I`, so the storage cost is two values of
that type.
"""
struct OneHotVector{I<:Integer} <: AbstractVector{Bool}
    ix::I
    of::I
end
Base.size(xs::OneHotVector) = (Int64(xs.of),)
# The vector's length is its `of` field, converted to `Int` for the size tuple.
function Base.size(xs::OneHotVector)
    return (Int(xs.of),)
end
# An entry is `true` exactly when its index equals the stored hot index `ix`.
# The index is not checked against the vector's length here.
function Base.getindex(xs::OneHotVector, i::Integer)
    return i == xs.ix
end
# Multiplying by a one-hot vector just selects the column of `A` at the hot index.
function *(A::AbstractMatrix, b::OneHotVector)
    return A[:, b.ix]
end
struct OneHotMatrix{A<:AbstractVector{OneHotVector}} <: AbstractMatrix{Bool}
"""
A matrix of one-hot column vectors
"""
struct OneHotMatrix{V<:AbstractVector{<:Integer}} <: AbstractMatrix{Bool}
    # Number of rows, i.e. the length of every one-hot column.
    height::Int
    # Hot index of each column; `length(data)` is the number of columns.
    data::V
end
Base.size(xs::OneHotMatrix) = (Int64(xs.height),length(xs.data))
"""
    OneHotMatrix(xs::AbstractVector{<:OneHotVector})

Build a `OneHotMatrix` whose columns are the one-hot vectors in `xs`, keeping
only each vector's hot index plus the common length.

Accepts any `AbstractVector` (not just `Vector`), so views and other vector
types forwarded by callers such as `batch` work as well. Throws an error if the
vectors do not all have the same length; indexing an empty `xs` fails when the
first element is read.
"""
function OneHotMatrix(xs::AbstractVector{<:OneHotVector})
    # `first` rather than `xs[1]` so vectors with non-1-based indices also work.
    height = length(first(xs))
    data = map(xs) do x
        length(x) == height || error("All one hot vectors must be the same length")
        x.ix
    end
    return OneHotMatrix(height, data)
end
Base.getindex(xs::OneHotMatrix, i::Integer, j::Integer) = xs.data[j][i]
Base.getindex(xs::OneHotMatrix, ::Colon, i::Integer) = xs.data[i]
# Rows come from the stored `height`; there is one column per entry of `data`.
function Base.size(xs::OneHotMatrix)
    ncols = length(xs.data)
    return (xs.height, ncols)
end
# Column `i` as a `OneHotVector`, rebuilt from the stored hot index and the height.
# `OneHotVector{T}` requires both fields to share one integer type, but `height`
# is always an `Int` while `data` may hold narrower integers (e.g. `UInt8`).
# Promote the pair first; passing them unpromoted raises a `MethodError`
# whenever `eltype(xs.data) != Int`.
function Base.getindex(xs::OneHotMatrix, ::Colon, i::Integer)
    ix, of = promote(xs.data[i], xs.height)
    return OneHotVector(ix, of)
end
# Scalar lookup: fetch column `j` as a one-hot vector, then read its entry `i`.
function Base.getindex(xs::OneHotMatrix, i::Integer, j::Integer)
    column = xs[:, j]
    return column[i]
end
# Selecting several columns yields another `OneHotMatrix` of the same height,
# built from the corresponding entries of `data`.
function Base.getindex(xs::OneHotMatrix, ::Colon, i::AbstractArray)
    selected = xs.data[i]
    return OneHotMatrix(xs.height, selected)
end
A::AbstractMatrix * B::OneHotMatrix = A[:, map(x->x.ix, B.data)]
# Multiplying by a one-hot matrix gathers the columns of `A` at the stored hot
# indices, one per column of `B`.
function *(A::AbstractMatrix, B::OneHotMatrix)
    return A[:, B.data]
end
Base.hcat(x::OneHotVector, xs::OneHotVector...) = OneHotMatrix(length(x), [x, xs...])
# Horizontally concatenating one-hot vectors produces a `OneHotMatrix`; the
# vectors are gathered into an array and handed to the one-argument constructor.
function Base.hcat(x::OneHotVector, xs::OneHotVector...)
    columns = [x, xs...]
    return OneHotMatrix(columns)
end
batch(xs::AbstractArray{<:OneHotVector}) = OneHotMatrix(length(first(xs)), xs)
# Batching one-hot vectors delegates to the one-argument `OneHotMatrix` constructor.
function batch(xs::AbstractArray{<:OneHotVector})
    return OneHotMatrix(xs)
end
import Adapt: adapt, adapt_structure
@ -39,20 +51,22 @@ adapt_structure(T, xs::OneHotMatrix) = OneHotMatrix(xs.height, adapt(T, xs.data)
# Convert a `CuArray`-backed `OneHotMatrix` by applying `cudaconvert` to its
# index data; the height is carried over unchanged.
function cudaconvert(x::OneHotMatrix{<:CuArray})
    converted = cudaconvert(x.data)
    return OneHotMatrix(x.height, converted)
end
end
function onehot(l, labels)
i = something(findfirst(isequal(l), labels), 0)
i > 0 || error("Value $l is not in labels")
OneHotVector(i, length(labels))
# Position of the first element of `labels` equal (by `isequal`) to `l`.
# Raises an error naming the value when `l` does not occur in `labels`.
function onehotidx(l, labels)
    idx = findfirst(isequal(l), labels)
    if idx === nothing
        error("Value $(repr(l; context=:limited=>true)) is not in labels")
    end
    return idx
end
function onehot(l, labels, unk)
i = something(findfirst(isequal(l), labels), 0)
i > 0 || return onehot(unk, labels)
OneHotVector(i, length(labels))
# Position of `l` in `labels`; when `l` is absent, fall back to the position of
# the substitute value `unk` via the two-argument method.
function onehotidx(l, labels, unk)
    idx = findfirst(isequal(l), labels)
    return idx === nothing ? onehotidx(unk, labels) : idx
end
"""
    onehot(l, labels[, unk])

Return the `OneHotVector` of length `length(labels)` whose hot index is the
position of `l` in `labels`, as computed by `onehotidx`. Any `unk` argument is
forwarded to `onehotidx` unchanged.
"""
onehot(l, labels, unk...) = OneHotVector(onehotidx(l, labels, unk...), length(labels))
# One-hot encode every element of `ls` against `labels` as the columns of a
# `OneHotMatrix`. The matrix stores integer hot indices rather than
# `OneHotVector`s, so the indices are collected with `onehotidx`; any `unk`
# argument is forwarded to it unchanged.
onehotbatch(ls, labels, unk...) =
    OneHotMatrix(length(labels), [onehotidx(l, labels, unk...) for l in ls])
# Inverse of one-hot encoding for a vector: return the label at the position of
# the largest entry of `y`. Without explicit labels the position itself is returned.
function onecold(y::AbstractVector, labels = 1:length(y))
    hottest = Base.argmax(y)
    return labels[hottest]
end