OneHotMatrix WIP
This commit is contained in:
parent
be0133fb67
commit
58e299eafb
@ -14,10 +14,12 @@ A::AbstractMatrix * b::OneHotVector = A[:, b.ix]
|
|||||||
"""
|
"""
|
||||||
A matrix of one-hot column vectors
|
A matrix of one-hot column vectors
|
||||||
"""
|
"""
|
||||||
struct OneHotMatrix{A<:AbstractVector{<:Integer}} <: AbstractMatrix{Bool}
|
struct OneHotMatrix{height, A<:AbstractVector{<:Integer}} <: AbstractMatrix{Bool}
|
||||||
height::Int
|
|
||||||
data::A
|
data::A
|
||||||
end
|
end
|
||||||
|
Flux.OneHotMatrix{height}(data::AbstractVector{<:Integer}) where {height} =
|
||||||
|
OneHotMatrix{height, typeof(data)}(data)
|
||||||
|
Flux.OneHotMatrix(height, data) = OneHotMatrix{height}(data)
|
||||||
|
|
||||||
function OneHotMatrix(xs::Vector{<:OneHotVector})
|
function OneHotMatrix(xs::Vector{<:OneHotVector})
|
||||||
height = length(xs[1])
|
height = length(xs[1])
|
||||||
@ -28,11 +30,11 @@ function OneHotMatrix(xs::Vector{<:OneHotVector})
|
|||||||
end
|
end
|
||||||
|
|
||||||
|
|
||||||
Base.size(xs::OneHotMatrix) = (xs.height, length(xs.data))
|
Base.size(xs::OneHotMatrix{height}) where {height} = (height, length(xs.data))
|
||||||
|
|
||||||
Base.getindex(xs::OneHotMatrix, ::Colon, i::Integer) = OneHotVector(xs.data[i], xs.height)
|
Base.getindex(xs::OneHotMatrix, ::Colon, i::Integer) = OneHotVector(xs.data[i], size(xs)[1])
|
||||||
Base.getindex(xs::OneHotMatrix, i::Integer, j::Integer) = xs[:, j][i]
|
Base.getindex(xs::OneHotMatrix, i::Integer, j::Integer) = xs[:, j][i]
|
||||||
Base.getindex(xs::OneHotMatrix, ::Colon, i::AbstractArray) = OneHotMatrix(xs.height, xs.data[i])
|
Base.getindex(xs::OneHotMatrix, ::Colon, i::AbstractArray) = OneHotMatrix(size(xs)[1], xs.data[i])
|
||||||
|
|
||||||
A::AbstractMatrix * B::OneHotMatrix = A[:, B.data]
|
A::AbstractMatrix * B::OneHotMatrix = A[:, B.data]
|
||||||
|
|
||||||
@ -42,13 +44,13 @@ batch(xs::AbstractArray{<:OneHotVector}) = OneHotMatrix(xs)
|
|||||||
|
|
||||||
import Adapt.adapt
|
import Adapt.adapt
|
||||||
|
|
||||||
adapt(T, xs::OneHotMatrix) = OneHotMatrix(xs.height, adapt(T, xs.data))
|
adapt(T, xs::OneHotMatrix) = OneHotMatrix(size(xs)[1], adapt(T, xs.data))
|
||||||
|
|
||||||
@init @require CuArrays="3a865a2d-5b23-5a0f-bc46-62713ec82fae" begin
|
@init @require CuArrays="3a865a2d-5b23-5a0f-bc46-62713ec82fae" begin
|
||||||
import .CuArrays: CuArray, cudaconvert
|
import .CuArrays: CuArray, cudaconvert
|
||||||
import Base.Broadcast: BroadcastStyle, ArrayStyle
|
import Base.Broadcast: BroadcastStyle, ArrayStyle
|
||||||
BroadcastStyle(::Type{<:OneHotMatrix{<:CuArray}}) = ArrayStyle{CuArray}()
|
BroadcastStyle(::Type{<:OneHotMatrix{<:CuArray}}) = ArrayStyle{CuArray}()
|
||||||
cudaconvert(x::OneHotMatrix{<:CuArray}) = OneHotMatrix(x.height, cudaconvert(x.data))
|
cudaconvert(x::OneHotMatrix{<:CuArray}) = OneHotMatrix(size(x)[1], cudaconvert(x.data))
|
||||||
end
|
end
|
||||||
|
|
||||||
function onehotidx(l, labels)
|
function onehotidx(l, labels)
|
||||||
@ -63,7 +65,7 @@ function onehotidx(l, labels, unk)
|
|||||||
i
|
i
|
||||||
end
|
end
|
||||||
|
|
||||||
onehot(l, labels, unk...) = OneHotVector(onhotidx(l, labels, unk...), length(labels))
|
onehot(l, labels, unk...) = OneHotVector(onehotidx(l, labels, unk...), length(labels))
|
||||||
|
|
||||||
onehotbatch(ls, labels, unk...) =
|
onehotbatch(ls, labels, unk...) =
|
||||||
OneHotMatrix(length(labels), [onehotidx(l, labels, unk...) for l in ls])
|
OneHotMatrix(length(labels), [onehotidx(l, labels, unk...) for l in ls])
|
||||||
|
Loading…
Reference in New Issue
Block a user