2017-11-09 14:53:26 +00:00
|
|
|
|
import Base: *
|
|
|
|
|
|
2017-09-06 22:58:55 +00:00
|
|
|
|
"""
    OneHotVector(ix, of)

A `Bool` vector of length `of` whose only `true` entry is at index `ix`.
Both fields are stored as `UInt32`; no dense storage is allocated.
"""
struct OneHotVector <: AbstractVector{Bool}
  ix::UInt32  # index of the single `true` entry
  of::UInt32  # length of the vector
end

# `size` of any AbstractArray must be a tuple of `Int` (the platform word size),
# so convert from the stored `UInt32` to `Int` rather than to a fixed `Int64`.
Base.size(xs::OneHotVector) = (Int(xs.of),)

# Scalar indexing. The bounds check makes out-of-range indices throw `BoundsError`
# (as for any AbstractArray) instead of silently returning `false`; `@inline` lets
# callers elide the check under `@inbounds`.
@inline function Base.getindex(xs::OneHotVector, i::Integer)
  @boundscheck checkbounds(xs, i)
  return i == xs.ix
end

# `xs[:]` returns a fresh OneHotVector with the same contents.
Base.getindex(xs::OneHotVector, ::Colon) = OneHotVector(xs.ix, xs.of)

# Multiplying by a one-hot vector just selects column `b.ix` of `A`.
# Throws `DimensionMismatch` when the inner dimensions disagree, matching the
# behaviour of a dense matrix-vector product.
function Base.:*(A::AbstractMatrix, b::OneHotVector)
  size(A, 2) == length(b) ||
    throw(DimensionMismatch("matrix has $(size(A, 2)) columns but one-hot vector has length $(length(b))"))
  return A[:, b.ix]
end
|
2017-09-06 22:58:55 +00:00
|
|
|
|
|
2017-09-27 20:58:34 +00:00
|
|
|
|
"""
    OneHotMatrix(height, data)

A `Bool` matrix with `height` rows and one column per element of `data`;
column `j` is the `OneHotVector` `data[j]`.
"""
struct OneHotMatrix{A<:AbstractVector{OneHotVector}} <: AbstractMatrix{Bool}
  height::Int  # number of rows
  data::A      # one OneHotVector per column
end

# `size` must be a tuple of `Int`s; `height` is already an `Int`, so wrapping it
# in `Int64` would produce a mixed-type tuple on 32-bit platforms.
Base.size(xs::OneHotMatrix) = (xs.height, length(xs.data))

# Element(s) `i` of column `j`.
Base.getindex(xs::OneHotMatrix, i::Union{Integer, AbstractVector}, j::Integer) = xs.data[j][i]

# A single column is returned as its stored OneHotVector.
Base.getindex(xs::OneHotMatrix, ::Colon, i::Integer) = xs.data[i]

# Selecting several columns keeps the one-hot representation.
Base.getindex(xs::OneHotMatrix, ::Colon, i::AbstractArray) = OneHotMatrix(xs.height, xs.data[i])

# `xs[:, :]` copies the column storage, so the result does not share `data` with `xs`.
Base.getindex(xs::OneHotMatrix, ::Colon, ::Colon) = OneHotMatrix(xs.height, copy(xs.data))

# Row `i` is built by indexing every column with `i`.
Base.getindex(xs::OneHotMatrix, i::Integer, ::Colon) = map(x -> x[i], xs.data)
|
|
|
|
|
|
2017-11-09 14:53:26 +00:00
|
|
|
|
# Multiplying by a one-hot matrix gathers, for each column of `B`, the column of `A`
# at that column's `true` index. Throws `DimensionMismatch` when the inner
# dimensions disagree instead of silently returning a result.
function Base.:*(A::AbstractMatrix, B::OneHotMatrix)
  size(A, 2) == size(B, 1) ||
    throw(DimensionMismatch("matrix has $(size(A, 2)) columns but one-hot matrix has $(size(B, 1)) rows"))
  return A[:, map(x -> x.ix, B.data)]
end
|
2017-09-06 22:58:55 +00:00
|
|
|
|
|
2017-10-15 22:44:16 +00:00
|
|
|
|
# Concatenate one-hot vectors column-wise into a OneHotMatrix.
# All vectors must share one length (the matrix height); otherwise a
# `DimensionMismatch` is thrown, as `hcat` does for ordinary vectors, rather
# than producing a matrix whose columns disagree with its `height`.
function Base.hcat(x::OneHotVector, xs::OneHotVector...)
  n = length(x)
  for y in xs
    length(y) == n ||
      throw(DimensionMismatch("all one-hot vectors must have length $n, got $(length(y))"))
  end
  return OneHotMatrix(n, [x, xs...])
end
|
2017-09-06 22:58:55 +00:00
|
|
|
|
|
2017-10-15 22:44:40 +00:00
|
|
|
|
# Batch a collection of one-hot vectors into a OneHotMatrix without copying them.
# The height is taken from the first vector; every vector must have that same
# length, otherwise a `DimensionMismatch` is thrown instead of building an
# inconsistent matrix.
function batch(xs::AbstractArray{<:OneHotVector})
  n = length(first(xs))
  all(x -> length(x) == n, xs) ||
    throw(DimensionMismatch("all one-hot vectors must have length $n"))
  return OneHotMatrix(n, xs)
end
|
|
|
|
|
|
2018-11-14 15:34:45 +00:00
|
|
|
|
import Adapt: adapt, adapt_structure
|
2017-10-04 17:55:56 +00:00
|
|
|
|
|
2018-11-14 15:34:45 +00:00
|
|
|
|
# Adapt a OneHotMatrix by adapting only its column storage; the height is kept
# and the result is again a OneHotMatrix.
function adapt_structure(T, xs::OneHotMatrix)
  adapted_columns = adapt(T, xs.data)
  return OneHotMatrix(xs.height, adapted_columns)
end
|
2017-10-04 17:55:56 +00:00
|
|
|
|
|
2018-08-03 11:54:24 +00:00
|
|
|
|
# Definitions that depend on CuArrays; `@require` defers them until that package is loaded.
@init @require CuArrays="3a865a2d-5b23-5a0f-bc46-62713ec82fae" begin
  import .CuArrays: CuArray, cudaconvert
  import Base.Broadcast: BroadcastStyle, ArrayStyle

  # Broadcasting over a OneHotMatrix whose storage is a CuArray uses CuArray's broadcast style.
  function BroadcastStyle(::Type{<:OneHotMatrix{<:CuArray}})
    return ArrayStyle{CuArray}()
  end

  # Convert the column storage with `cudaconvert`, keeping the OneHotMatrix wrapper.
  function cudaconvert(x::OneHotMatrix{<:CuArray})
    return OneHotMatrix(x.height, cudaconvert(x.data))
  end
end
|
|
|
|
|
|
2019-04-26 10:05:03 +00:00
|
|
|
|
"""
|
|
|
|
|
onehot(l, labels[, unk])
|
|
|
|
|
|
|
|
|
|
Create a [`OneHotVector`](@ref) of length `length(labels)` whose only `true` element is at the position where `l` first occurs in `labels`.
|
2019-04-26 10:09:14 +00:00
|
|
|
|
If `l` is not found in `labels`, returns `onehot(unk, labels)` when `unk` is given;
otherwise it throws an error.
|
|
|
|
|
|
|
|
|
|
## Examples
|
|
|
|
|
|
|
|
|
|
```jldoctest
|
|
|
|
|
julia> onehot(:b, [:a, :b, :c])
|
|
|
|
|
3-element Flux.OneHotVector:
|
|
|
|
|
false
|
|
|
|
|
true
|
|
|
|
|
false
|
|
|
|
|
|
|
|
|
|
julia> onehot(:c, [:a, :b, :c])
|
|
|
|
|
3-element Flux.OneHotVector:
|
|
|
|
|
false
|
|
|
|
|
false
|
|
|
|
|
true
|
|
|
|
|
```
|
|
|
|
|
"""
|
2017-10-16 23:07:58 +00:00
|
|
|
|
function onehot(l, labels)
  # Position of the first label equal to `l`; 0 stands in for "not found".
  pos = something(findfirst(isequal(l), labels), 0)
  if pos <= 0
    error("Value $l is not in labels")
  end
  OneHotVector(pos, length(labels))
end
|
|
|
|
|
|
2017-11-29 16:45:50 +00:00
|
|
|
|
function onehot(l, labels, unk)
  # Position of the first label equal to `l`; 0 stands in for "not found",
  # in which case the fallback label `unk` is encoded instead.
  pos = something(findfirst(isequal(l), labels), 0)
  pos > 0 ? OneHotVector(pos, length(labels)) : onehot(unk, labels)
end
|
|
|
|
|
|
2019-04-26 10:05:03 +00:00
|
|
|
|
"""
|
|
|
|
|
onehotbatch(ls, labels[, unk...])
|
|
|
|
|
|
|
|
|
|
Create a [`OneHotMatrix`](@ref) whose columns are `onehot(l, labels[, unk])` for each `l` in `ls`.
If `unk` is given, any element of `ls` not found in `labels` is encoded as `onehot(unk, labels)`;
otherwise such an element causes an error.
|
|
|
|
|
|
|
|
|
|
## Examples
|
|
|
|
|
|
|
|
|
|
```jldoctest
|
|
|
|
|
julia> onehotbatch([:b, :a, :b], [:a, :b, :c])
|
|
|
|
|
3×3 Flux.OneHotMatrix:
|
|
|
|
|
false true false
|
|
|
|
|
true false true
|
|
|
|
|
false false false
|
|
|
|
|
|
|
|
|
|
```
|
|
|
|
|
"""
|
2017-11-29 16:45:50 +00:00
|
|
|
|
function onehotbatch(ls, labels, unk...)
  # One column per element of `ls`, each encoded against the same `labels`.
  columns = collect(onehot(l, labels, unk...) for l in ls)
  return OneHotMatrix(length(labels), columns)
end
|
2017-09-06 22:58:55 +00:00
|
|
|
|
|
2019-02-09 17:02:02 +00:00
|
|
|
|
# The maximum of a one-hot vector is at its single `true` entry. Return it as an
# `Int`, like `argmax` on any other vector, rather than the stored `UInt32`.
Base.argmax(xs::OneHotVector) = Int(xs.ix)
|
|
|
|
|
|
2019-04-26 10:05:03 +00:00
|
|
|
|
"""
|
|
|
|
|
onecold(y[, labels = 1:length(y)])
|
|
|
|
|
|
|
|
|
|
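Inverse operation of [`onehot`](@ref): returns the element of `labels` at the position of the
largest entry of `y` (for a matrix, one label per column).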
|
|
|
|
|
|
|
|
|
|
## Examples
|
|
|
|
|
|
|
|
|
|
```jldoctest
|
|
|
|
|
julia> onecold([true, false, false], [:a, :b, :c])
|
|
|
|
|
:a
|
|
|
|
|
|
|
|
|
|
julia> onecold([0.3, 0.2, 0.5], [:a, :b, :c])
|
|
|
|
|
:c
|
|
|
|
|
```
|
|
|
|
|
"""
|
2018-08-23 14:44:28 +00:00
|
|
|
|
# Label at the position of the largest entry of `y`; without `labels`, that position itself.
function onecold(y::AbstractVector, labels = 1:length(y))
  return labels[Base.argmax(y)]
end

# Decode every column of `y` independently, returning one label per column.
function onecold(y::AbstractMatrix, labels...)
  per_column = mapslices(col -> onecold(col, labels...), y, dims=1)
  return dropdims(per_column, dims=1)
end
|
|
|
|
|
|
2019-02-28 03:47:18 +00:00
|
|
|
|
# Decode each stored column directly, returning one label per column.
# `map` (rather than `mapreduce` with `|` and `init = 0`) works for any label type,
# e.g. Symbols, where `0 | label` threw a MethodError. It also returns a vector,
# like `onecold` on any other AbstractMatrix.
onecold(y::OneHotMatrix, labels...) =
  map(x -> onecold(x, labels...), y.data)
|
2019-02-09 17:02:02 +00:00
|
|
|
|
|
2019-03-08 12:06:09 +00:00
|
|
|
|
# TODO: we probably still want these as custom Zygote adjoints
|
|
|
|
|
# onecold(x::TrackedVector, l...) = onecold(data(x), l...)
|
|
|
|
|
# onecold(x::TrackedMatrix, l...) = onecold(data(x), l...)
|