Compare commits

...

1 Commits

Author SHA1 Message Date
Keno Fischer
45c7ab8e6d Some memory improvements to OneHotMatrix
1. Parameterize OneHotVector on its Integer index type, so that vectors of
   OneHotVectors do not use more memory than required.
2. Switch OneHotMatrix from storing a vector of OneHotVectors to only storing
   the data and the size of the vector (reconstructing the vector locally), thus
   saving half the memory required and eliminating a transpose operation for
   matmul with OneHotMatrix on TPU.
2019-01-28 20:26:59 -05:00

View File

@ -1,32 +1,44 @@
import Base: * import Base: *
struct OneHotVector <: AbstractVector{Bool} struct OneHotVector{T <: Integer} <: AbstractVector{Bool}
ix::UInt32 ix::T
of::UInt32 of::T
end end
Base.size(xs::OneHotVector) = (Int64(xs.of),) Base.size(xs::OneHotVector) = (Int(xs.of),)
Base.getindex(xs::OneHotVector, i::Integer) = i == xs.ix Base.getindex(xs::OneHotVector, i::Integer) = i == xs.ix
A::AbstractMatrix * b::OneHotVector = A[:, b.ix] A::AbstractMatrix * b::OneHotVector = A[:, b.ix]
struct OneHotMatrix{A<:AbstractVector{OneHotVector}} <: AbstractMatrix{Bool} """
A matrix of one-hot column vectors
"""
struct OneHotMatrix{A<:AbstractVector{<:Integer}} <: AbstractMatrix{Bool}
height::Int height::Int
data::A data::A
end end
Base.size(xs::OneHotMatrix) = (Int64(xs.height),length(xs.data)) function OneHotMatrix(xs::Vector{<:OneHotVector})
height = length(xs[1])
OneHotMatrix(height, map(xs) do x
length(x) == height || error("All one hot vectors must be the same length")
x.ix
end)
end
Base.getindex(xs::OneHotMatrix, i::Integer, j::Integer) = xs.data[j][i]
Base.getindex(xs::OneHotMatrix, ::Colon, i::Integer) = xs.data[i] Base.size(xs::OneHotMatrix) = (xs.height, length(xs.data))
# Column access: rebuild the one-hot vector for column `i` from its stored index.
# `promote` is needed because the default `OneHotVector(ix::T, of::T)` constructor
# requires both arguments to share one type, while `xs.height` is an `Int` and
# `xs.data[i]` may be a different integer type (e.g. `UInt32`).
Base.getindex(xs::OneHotMatrix, ::Colon, i::Integer) =
    OneHotVector(promote(xs.data[i], xs.height)...)
# Scalar access: fetch column `j` through the `(::Colon, ::Integer)` method,
# then index that column at row `i`.
function Base.getindex(xs::OneHotMatrix, i::Integer, j::Integer)
    column = xs[:, j]
    return column[i]
end
Base.getindex(xs::OneHotMatrix, ::Colon, i::AbstractArray) = OneHotMatrix(xs.height, xs.data[i]) Base.getindex(xs::OneHotMatrix, ::Colon, i::AbstractArray) = OneHotMatrix(xs.height, xs.data[i])
A::AbstractMatrix * B::OneHotMatrix = A[:, map(x->x.ix, B.data)] A::AbstractMatrix * B::OneHotMatrix = A[:, B.data]
Base.hcat(x::OneHotVector, xs::OneHotVector...) = OneHotMatrix(length(x), [x, xs...]) Base.hcat(x::OneHotVector, xs::OneHotVector...) = OneHotMatrix([x, xs...])
batch(xs::AbstractArray{<:OneHotVector}) = OneHotMatrix(length(first(xs)), xs) batch(xs::AbstractArray{<:OneHotVector}) = OneHotMatrix(xs)
import Adapt: adapt, adapt_structure import Adapt: adapt, adapt_structure
@ -39,20 +51,22 @@ adapt_structure(T, xs::OneHotMatrix) = OneHotMatrix(xs.height, adapt(T, xs.data)
cudaconvert(x::OneHotMatrix{<:CuArray}) = OneHotMatrix(x.height, cudaconvert(x.data)) cudaconvert(x::OneHotMatrix{<:CuArray}) = OneHotMatrix(x.height, cudaconvert(x.data))
end end
function onehot(l, labels) function onehotidx(l, labels)
i = something(findfirst(isequal(l), labels), 0) i = findfirst(isequal(l), labels)
i > 0 || error("Value $l is not in labels") i !== nothing || error("Value $(repr(l; context=:limited=>true)) is not in labels")
OneHotVector(i, length(labels)) i
end end
function onehot(l, labels, unk) function onehotidx(l, labels, unk)
i = something(findfirst(isequal(l), labels), 0) i = findfirst(isequal(l), labels)
i > 0 || return onehot(unk, labels) i !== nothing || return onehotidx(unk, labels)
OneHotVector(i, length(labels)) i
end end
# Encode label `l` as a `OneHotVector` of length `length(labels)`; any trailing `unk`
# argument is forwarded unchanged to `onehotidx`.
# Fix: the helper name was misspelled `onhotidx`, which is not defined anywhere here.
onehot(l, labels, unk...) = OneHotVector(onehotidx(l, labels, unk...), length(labels))
onehotbatch(ls, labels, unk...) = onehotbatch(ls, labels, unk...) =
OneHotMatrix(length(labels), [onehot(l, labels, unk...) for l in ls]) OneHotMatrix(length(labels), [onehotidx(l, labels, unk...) for l in ls])
onecold(y::AbstractVector, labels = 1:length(y)) = labels[Base.argmax(y)] onecold(y::AbstractVector, labels = 1:length(y)) = labels[Base.argmax(y)]