From 770f601897aa2581a7c915234a7389b6645ef55d Mon Sep 17 00:00:00 2001
From: Keno Fischer
Date: Mon, 28 Jan 2019 20:23:50 -0500
Subject: [PATCH] Some memory improvements to OneHotMatrix

1. Parameterize OneHotVector on Integer type, to avoid using more
   memory than required for vectors of them.
2. Switch OneHotMatrix from storing a vector of OneHotVectors to
   only storing the data and the size of the vector (reconstructing
   the vector locally), thus saving half the memory required and
   eliminating a transpose operation for matmul with OneHotMatrix on TPU.
---
Review notes (text between the "---" separator and the diff is not part of
the commit):
 * Fixed a misspelled call in the added `onehot` method: `onhotidx` is now
   `onehotidx`. The misspelled name is not defined anywhere in this patch,
   so calls to `onehot` would have failed on an undefined name. The blob
   hash on the `index` line was not recomputed after this edit.
 * Line breaks and indentation were reconstructed from a whitespace-collapsed
   copy of this patch. Hunk line counts match the hunk headers, but confirm
   the indentation of context and removed lines against src/onehot.jl
   before applying.
 * NOTE(review): `batch` accepts any AbstractArray of OneHotVector, while the
   new one-argument OneHotMatrix constructor only accepts `Vector`, and it
   reads `xs[1]`, which fails on an empty vector. Consider a follow-up.

 src/onehot.jl | 54 ++++++++++++++++++++++++++++++++++--------------------
 1 file changed, 34 insertions(+), 20 deletions(-)

diff --git a/src/onehot.jl b/src/onehot.jl
index 5d902c77..4eafbcb6 100644
--- a/src/onehot.jl
+++ b/src/onehot.jl
@@ -1,32 +1,44 @@
 import Base: *
 
-struct OneHotVector <: AbstractVector{Bool}
-  ix::UInt32
-  of::UInt32
+struct OneHotVector{T <: Integer} <: AbstractVector{Bool}
+  ix::T
+  of::T
 end
 
-Base.size(xs::OneHotVector) = (Int64(xs.of),)
+Base.size(xs::OneHotVector) = (Int(xs.of),)
 
 Base.getindex(xs::OneHotVector, i::Integer) = i == xs.ix
 
 A::AbstractMatrix * b::OneHotVector = A[:, b.ix]
 
-struct OneHotMatrix{A<:AbstractVector{OneHotVector}} <: AbstractMatrix{Bool}
+"""
+    A matrix of one-hot column vectors
+"""
+struct OneHotMatrix{A<:AbstractVector{<:Integer}} <: AbstractMatrix{Bool}
   height::Int
   data::A
 end
 
-Base.size(xs::OneHotMatrix) = (Int64(xs.height),length(xs.data))
+function OneHotMatrix(xs::Vector{<:OneHotVector})
+  height = length(xs[1])
+  OneHotMatrix(height, map(xs) do x
+    length(x) == height || error("All one hot vectors must be the same length")
+    x.ix
+  end)
+end
 
-Base.getindex(xs::OneHotMatrix, i::Integer, j::Integer) = xs.data[j][i]
-Base.getindex(xs::OneHotMatrix, ::Colon, i::Integer) = xs.data[i]
+
+Base.size(xs::OneHotMatrix) = (xs.height, length(xs.data))
+
+Base.getindex(xs::OneHotMatrix, ::Colon, i::Integer) = OneHotVector(xs.data[i], xs.height)
+Base.getindex(xs::OneHotMatrix, i::Integer, j::Integer) = xs[:, j][i]
 Base.getindex(xs::OneHotMatrix, ::Colon, i::AbstractArray) = OneHotMatrix(xs.height, xs.data[i])
 
-A::AbstractMatrix * B::OneHotMatrix = A[:, map(x->x.ix, B.data)]
+A::AbstractMatrix * B::OneHotMatrix = A[:, B.data]
 
-Base.hcat(x::OneHotVector, xs::OneHotVector...) = OneHotMatrix(length(x), [x, xs...])
+Base.hcat(x::OneHotVector, xs::OneHotVector...) = OneHotMatrix([x, xs...])
 
-batch(xs::AbstractArray{<:OneHotVector}) = OneHotMatrix(length(first(xs)), xs)
+batch(xs::AbstractArray{<:OneHotVector}) = OneHotMatrix(xs)
 
 import Adapt.adapt
 
@@ -39,20 +51,22 @@ adapt(T, xs::OneHotMatrix) = OneHotMatrix(xs.height, adapt(T, xs.data))
   cudaconvert(x::OneHotMatrix{<:CuArray}) = OneHotMatrix(x.height, cudaconvert(x.data))
 end
 
-function onehot(l, labels)
-  i = something(findfirst(isequal(l), labels), 0)
-  i > 0 || error("Value $l is not in labels")
-  OneHotVector(i, length(labels))
+function onehotidx(l, labels)
+  i = findfirst(isequal(l), labels)
+  i !== nothing || error("Value $(repr(l; context=:limited=>true)) is not in labels")
+  i
 end
 
-function onehot(l, labels, unk)
-  i = something(findfirst(isequal(l), labels), 0)
-  i > 0 || return onehot(unk, labels)
-  OneHotVector(i, length(labels))
+function onehotidx(l, labels, unk)
+  i = findfirst(isequal(l), labels)
+  i !== nothing || return onehotidx(unk, labels)
+  i
 end
 
+onehot(l, labels, unk...) = OneHotVector(onehotidx(l, labels, unk...), length(labels))
+
 onehotbatch(ls, labels, unk...) =
-  OneHotMatrix(length(labels), [onehot(l, labels, unk...) for l in ls])
+  OneHotMatrix(length(labels), [onehotidx(l, labels, unk...) for l in ls])
 
 onecold(y::AbstractVector, labels = 1:length(y)) = labels[Base.argmax(y)]
 