OneHotMatrix WIP
This commit is contained in:
parent
be0133fb67
commit
58e299eafb
|
@ -14,10 +14,12 @@ A::AbstractMatrix * b::OneHotVector = A[:, b.ix]
|
|||
"""
|
||||
A matrix of one-hot column vectors
|
||||
"""
|
||||
struct OneHotMatrix{A<:AbstractVector{<:Integer}} <: AbstractMatrix{Bool}
|
||||
height::Int
|
||||
struct OneHotMatrix{height, A<:AbstractVector{<:Integer}} <: AbstractMatrix{Bool}
|
||||
data::A
|
||||
end
|
||||
Flux.OneHotMatrix{height}(data::AbstractVector{<:Integer}) where {height} =
|
||||
OneHotMatrix{height, typeof(data)}(data)
|
||||
Flux.OneHotMatrix(height, data) = OneHotMatrix{height}(data)
|
||||
|
||||
function OneHotMatrix(xs::Vector{<:OneHotVector})
|
||||
height = length(xs[1])
|
||||
|
@ -28,11 +30,11 @@ function OneHotMatrix(xs::Vector{<:OneHotVector})
|
|||
end
|
||||
|
||||
|
||||
Base.size(xs::OneHotMatrix) = (xs.height, length(xs.data))
|
||||
Base.size(xs::OneHotMatrix{height}) where {height} = (height, length(xs.data))
|
||||
|
||||
Base.getindex(xs::OneHotMatrix, ::Colon, i::Integer) = OneHotVector(xs.data[i], xs.height)
|
||||
Base.getindex(xs::OneHotMatrix, ::Colon, i::Integer) = OneHotVector(xs.data[i], size(xs)[1])
|
||||
Base.getindex(xs::OneHotMatrix, i::Integer, j::Integer) = xs[:, j][i]
|
||||
Base.getindex(xs::OneHotMatrix, ::Colon, i::AbstractArray) = OneHotMatrix(xs.height, xs.data[i])
|
||||
Base.getindex(xs::OneHotMatrix, ::Colon, i::AbstractArray) = OneHotMatrix(size(xs)[1], xs.data[i])
|
||||
|
||||
A::AbstractMatrix * B::OneHotMatrix = A[:, B.data]
|
||||
|
||||
|
@ -42,13 +44,13 @@ batch(xs::AbstractArray{<:OneHotVector}) = OneHotMatrix(xs)
|
|||
|
||||
import Adapt.adapt
|
||||
|
||||
adapt(T, xs::OneHotMatrix) = OneHotMatrix(xs.height, adapt(T, xs.data))
|
||||
adapt(T, xs::OneHotMatrix) = OneHotMatrix(size(xs)[1], adapt(T, xs.data))
|
||||
|
||||
@init @require CuArrays="3a865a2d-5b23-5a0f-bc46-62713ec82fae" begin
|
||||
import .CuArrays: CuArray, cudaconvert
|
||||
import Base.Broadcast: BroadcastStyle, ArrayStyle
|
||||
BroadcastStyle(::Type{<:OneHotMatrix{<:CuArray}}) = ArrayStyle{CuArray}()
|
||||
cudaconvert(x::OneHotMatrix{<:CuArray}) = OneHotMatrix(x.height, cudaconvert(x.data))
|
||||
cudaconvert(x::OneHotMatrix{<:CuArray}) = OneHotMatrix(size(x)[1], cudaconvert(x.data))
|
||||
end
|
||||
|
||||
function onehotidx(l, labels)
|
||||
|
@ -63,7 +65,7 @@ function onehotidx(l, labels, unk)
|
|||
i
|
||||
end
|
||||
|
||||
onehot(l, labels, unk...) = OneHotVector(onhotidx(l, labels, unk...), length(labels))
|
||||
onehot(l, labels, unk...) = OneHotVector(onehotidx(l, labels, unk...), length(labels))
|
||||
|
||||
onehotbatch(ls, labels, unk...) =
|
||||
OneHotMatrix(length(labels), [onehotidx(l, labels, unk...) for l in ls])
|
||||
|
|
Loading…
Reference in New Issue