OneHotMatrix WIP

This commit is contained in:
Keno Fischer 2019-02-13 16:48:13 -05:00
parent be0133fb67
commit 58e299eafb
1 changed files with 10 additions and 8 deletions

View File

@ -14,10 +14,12 @@ A::AbstractMatrix * b::OneHotVector = A[:, b.ix]
"""
A matrix of one-hot column vectors
"""
struct OneHotMatrix{A<:AbstractVector{<:Integer}} <: AbstractMatrix{Bool}
height::Int
struct OneHotMatrix{height, A<:AbstractVector{<:Integer}} <: AbstractMatrix{Bool}
data::A
end
Flux.OneHotMatrix{height}(data::AbstractVector{<:Integer}) where {height} =
OneHotMatrix{height, typeof(data)}(data)
Flux.OneHotMatrix(height, data) = OneHotMatrix{height}(data)
function OneHotMatrix(xs::Vector{<:OneHotVector})
height = length(xs[1])
@ -28,11 +30,11 @@ function OneHotMatrix(xs::Vector{<:OneHotVector})
end
Base.size(xs::OneHotMatrix) = (xs.height, length(xs.data))
Base.size(xs::OneHotMatrix{height}) where {height} = (height, length(xs.data))
Base.getindex(xs::OneHotMatrix, ::Colon, i::Integer) = OneHotVector(xs.data[i], xs.height)
Base.getindex(xs::OneHotMatrix, ::Colon, i::Integer) = OneHotVector(xs.data[i], size(xs)[1])
Base.getindex(xs::OneHotMatrix, i::Integer, j::Integer) = xs[:, j][i]
Base.getindex(xs::OneHotMatrix, ::Colon, i::AbstractArray) = OneHotMatrix(xs.height, xs.data[i])
Base.getindex(xs::OneHotMatrix, ::Colon, i::AbstractArray) = OneHotMatrix(size(xs)[1], xs.data[i])
A::AbstractMatrix * B::OneHotMatrix = A[:, B.data]
@ -42,13 +44,13 @@ batch(xs::AbstractArray{<:OneHotVector}) = OneHotMatrix(xs)
import Adapt.adapt
adapt(T, xs::OneHotMatrix) = OneHotMatrix(xs.height, adapt(T, xs.data))
adapt(T, xs::OneHotMatrix) = OneHotMatrix(size(xs)[1], adapt(T, xs.data))
@init @require CuArrays="3a865a2d-5b23-5a0f-bc46-62713ec82fae" begin
import .CuArrays: CuArray, cudaconvert
import Base.Broadcast: BroadcastStyle, ArrayStyle
BroadcastStyle(::Type{<:OneHotMatrix{<:CuArray}}) = ArrayStyle{CuArray}()
cudaconvert(x::OneHotMatrix{<:CuArray}) = OneHotMatrix(x.height, cudaconvert(x.data))
cudaconvert(x::OneHotMatrix{<:CuArray}) = OneHotMatrix(size(x)[1], cudaconvert(x.data))
end
function onehotidx(l, labels)
@ -63,7 +65,7 @@ function onehotidx(l, labels, unk)
i
end
onehot(l, labels, unk...) = OneHotVector(onhotidx(l, labels, unk...), length(labels))
onehot(l, labels, unk...) = OneHotVector(onehotidx(l, labels, unk...), length(labels))
onehotbatch(ls, labels, unk...) =
OneHotMatrix(length(labels), [onehotidx(l, labels, unk...) for l in ls])