Some memory improvements to OneHotMatrix
1. Parameterize OneHotVector on Integer type, to avoid using more memory than required for vectors of them. 2. Switch OneHotMatrix from storing a vector of OneHotVectors to only storing the data and the size of the vector (reconstructing the vector locally), thus saving half the memory required and eliminating a transpose operation for matmul with OneHotMatrix on TPU.
This commit is contained in:
parent
8386a49bf9
commit
45c7ab8e6d
@ -1,32 +1,44 @@
|
|||||||
import Base: *
|
import Base: *
|
||||||
|
|
||||||
struct OneHotVector <: AbstractVector{Bool}
|
struct OneHotVector{T <: Integer} <: AbstractVector{Bool}
|
||||||
ix::UInt32
|
ix::T
|
||||||
of::UInt32
|
of::T
|
||||||
end
|
end
|
||||||
|
|
||||||
Base.size(xs::OneHotVector) = (Int64(xs.of),)
|
Base.size(xs::OneHotVector) = (Int(xs.of),)
|
||||||
|
|
||||||
Base.getindex(xs::OneHotVector, i::Integer) = i == xs.ix
|
Base.getindex(xs::OneHotVector, i::Integer) = i == xs.ix
|
||||||
|
|
||||||
A::AbstractMatrix * b::OneHotVector = A[:, b.ix]
|
A::AbstractMatrix * b::OneHotVector = A[:, b.ix]
|
||||||
|
|
||||||
struct OneHotMatrix{A<:AbstractVector{OneHotVector}} <: AbstractMatrix{Bool}
|
"""
|
||||||
|
A matrix of one-hot column vectors
|
||||||
|
"""
|
||||||
|
struct OneHotMatrix{A<:AbstractVector{<:Integer}} <: AbstractMatrix{Bool}
|
||||||
height::Int
|
height::Int
|
||||||
data::A
|
data::A
|
||||||
end
|
end
|
||||||
|
|
||||||
Base.size(xs::OneHotMatrix) = (Int64(xs.height),length(xs.data))
|
function OneHotMatrix(xs::Vector{<:OneHotVector})
|
||||||
|
height = length(xs[1])
|
||||||
|
OneHotMatrix(height, map(xs) do x
|
||||||
|
length(x) == height || error("All one hot vectors must be the same length")
|
||||||
|
x.ix
|
||||||
|
end)
|
||||||
|
end
|
||||||
|
|
||||||
Base.getindex(xs::OneHotMatrix, i::Integer, j::Integer) = xs.data[j][i]
|
|
||||||
Base.getindex(xs::OneHotMatrix, ::Colon, i::Integer) = xs.data[i]
|
Base.size(xs::OneHotMatrix) = (xs.height, length(xs.data))
|
||||||
|
|
||||||
|
Base.getindex(xs::OneHotMatrix, ::Colon, i::Integer) = OneHotVector(xs.data[i], xs.height)
|
||||||
|
Base.getindex(xs::OneHotMatrix, i::Integer, j::Integer) = xs[:, j][i]
|
||||||
Base.getindex(xs::OneHotMatrix, ::Colon, i::AbstractArray) = OneHotMatrix(xs.height, xs.data[i])
|
Base.getindex(xs::OneHotMatrix, ::Colon, i::AbstractArray) = OneHotMatrix(xs.height, xs.data[i])
|
||||||
|
|
||||||
A::AbstractMatrix * B::OneHotMatrix = A[:, map(x->x.ix, B.data)]
|
A::AbstractMatrix * B::OneHotMatrix = A[:, B.data]
|
||||||
|
|
||||||
Base.hcat(x::OneHotVector, xs::OneHotVector...) = OneHotMatrix(length(x), [x, xs...])
|
Base.hcat(x::OneHotVector, xs::OneHotVector...) = OneHotMatrix([x, xs...])
|
||||||
|
|
||||||
batch(xs::AbstractArray{<:OneHotVector}) = OneHotMatrix(length(first(xs)), xs)
|
batch(xs::AbstractArray{<:OneHotVector}) = OneHotMatrix(xs)
|
||||||
|
|
||||||
import Adapt: adapt, adapt_structure
|
import Adapt: adapt, adapt_structure
|
||||||
|
|
||||||
@ -39,20 +51,22 @@ adapt_structure(T, xs::OneHotMatrix) = OneHotMatrix(xs.height, adapt(T, xs.data)
|
|||||||
cudaconvert(x::OneHotMatrix{<:CuArray}) = OneHotMatrix(x.height, cudaconvert(x.data))
|
cudaconvert(x::OneHotMatrix{<:CuArray}) = OneHotMatrix(x.height, cudaconvert(x.data))
|
||||||
end
|
end
|
||||||
|
|
||||||
function onehot(l, labels)
|
function onehotidx(l, labels)
|
||||||
i = something(findfirst(isequal(l), labels), 0)
|
i = findfirst(isequal(l), labels)
|
||||||
i > 0 || error("Value $l is not in labels")
|
i !== nothing || error("Value $(repr(l; context=:limited=>true)) is not in labels")
|
||||||
OneHotVector(i, length(labels))
|
i
|
||||||
end
|
end
|
||||||
|
|
||||||
function onehot(l, labels, unk)
|
function onehotidx(l, labels, unk)
|
||||||
i = something(findfirst(isequal(l), labels), 0)
|
i = findfirst(isequal(l), labels)
|
||||||
i > 0 || return onehot(unk, labels)
|
i !== nothing || return onehotidx(unk, labels)
|
||||||
OneHotVector(i, length(labels))
|
i
|
||||||
end
|
end
|
||||||
|
|
||||||
|
onehot(l, labels, unk...) = OneHotVector(onhotidx(l, labels, unk...), length(labels))
|
||||||
|
|
||||||
onehotbatch(ls, labels, unk...) =
|
onehotbatch(ls, labels, unk...) =
|
||||||
OneHotMatrix(length(labels), [onehot(l, labels, unk...) for l in ls])
|
OneHotMatrix(length(labels), [onehotidx(l, labels, unk...) for l in ls])
|
||||||
|
|
||||||
onecold(y::AbstractVector, labels = 1:length(y)) = labels[Base.argmax(y)]
|
onecold(y::AbstractVector, labels = 1:length(y)) = labels[Base.argmax(y)]
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user