Flux.jl/src/onehot.jl

69 lines
2.1 KiB
Julia
Raw Normal View History

2017-11-09 14:53:26 +00:00
import Base: *
2017-09-06 22:58:55 +00:00
"""
    OneHotVector(ix, of)

Compact `Bool` vector of length `of` that holds a single `true` at index `ix`.
Only those two numbers are stored (both as `UInt32`); elements are computed on
access rather than materialized.
"""
struct OneHotVector <: AbstractVector{Bool}
    ix::UInt32  # index of the hot (`true`) entry
    of::UInt32  # logical length of the vector
end
# The length is stored as `UInt32`, but `size` must return a tuple of the native
# `Int` (which is `Int32` on 32-bit platforms, so `Int64` would be wrong there).
Base.size(xs::OneHotVector) = (Int(xs.of),)
# Element `i` is `true` exactly at the stored hot index and `false` elsewhere.
# The explicit bounds check makes out-of-range indices throw `BoundsError`
# instead of silently reading `false`; it is elided under `@inbounds`.
Base.@propagate_inbounds function Base.getindex(xs::OneHotVector, i::Integer)
    @boundscheck checkbounds(xs, i)
    return i == xs.ix
end
2017-11-09 14:53:26 +00:00
# Multiplying by a one-hot vector selects column `b.ix` of `A`, so no arithmetic
# is needed. The inner dimensions must still agree, as for any matrix-vector
# product; otherwise a mismatched `A` would silently yield an unrelated column.
function Base.:*(A::AbstractMatrix, b::OneHotVector)
    size(A, 2) == b.of || throw(DimensionMismatch(
        "matrix A has $(size(A, 2)) columns, but one-hot vector b has length $(b.of)"))
    return A[:, b.ix]
end
2017-09-06 22:58:55 +00:00
2017-09-27 20:58:34 +00:00
"""
    OneHotMatrix(height, data)

`Bool` matrix whose columns are the `OneHotVector`s stored in `data`, with
`height` rows. The column container type `A` is a type parameter, so the storage
can be any `AbstractVector{OneHotVector}` (e.g. a GPU array).
"""
struct OneHotMatrix{A<:AbstractVector{OneHotVector}} <: AbstractMatrix{Bool}
    height::Int  # number of rows (expected to match each column's length)
    data::A      # one `OneHotVector` per column
end
2017-10-02 19:50:11 +00:00
# Rows come from the stored `height` (already an `Int`), columns from the number
# of stored one-hot vectors. Returning plain `Int`s keeps `size` a valid `Dims`
# tuple on every platform (an `Int64` conversion is wrong on 32-bit).
Base.size(xs::OneHotMatrix) = (xs.height, length(xs.data))
2017-09-06 22:58:55 +00:00
2017-12-15 16:17:39 +00:00
# Scalar access: entry (i, j) is `true` only when row `i` is the hot index of
# column `j`. Checking against the full matrix shape makes an out-of-range row
# index throw `BoundsError` instead of silently reading `false`.
Base.@propagate_inbounds function Base.getindex(xs::OneHotMatrix, i::Integer,
                                                j::Integer)
    @boundscheck checkbounds(xs, i, j)
    return xs.data[j][i]
end

# A whole column is the stored `OneHotVector` itself (no copy).
Base.getindex(xs::OneHotMatrix, ::Colon, i::Integer) = xs.data[i]

# Selecting several columns keeps the compact one-hot representation.
Base.getindex(xs::OneHotMatrix, ::Colon, i::AbstractArray) =
    OneHotMatrix(xs.height, xs.data[i])
2017-09-06 22:58:55 +00:00
2017-11-09 14:53:26 +00:00
# Multiplying by a one-hot matrix gathers, for each column of `B`, the column of
# `A` at that column's hot index. The inner dimensions must agree; otherwise a
# mismatched `A` would silently return unrelated columns.
function Base.:*(A::AbstractMatrix, B::OneHotMatrix)
    size(A, 2) == B.height || throw(DimensionMismatch(
        "matrix A has $(size(A, 2)) columns, but one-hot matrix B has $(B.height) rows"))
    return A[:, map(x -> x.ix, B.data)]
end
2017-09-06 22:58:55 +00:00
2017-10-15 22:44:16 +00:00
# Horizontally concatenating one-hot vectors yields a `OneHotMatrix`. All
# vectors must share one length, which becomes the matrix height; mixing lengths
# would otherwise produce a malformed matrix.
function Base.hcat(x::OneHotVector, xs::OneHotVector...)
    n = length(x)
    all(v -> length(v) == n, xs) || throw(DimensionMismatch(
        "one-hot vectors passed to hcat must all have the same length"))
    return OneHotMatrix(n, [x, xs...])
end
2017-09-06 22:58:55 +00:00
2017-10-15 22:44:40 +00:00
# Batch a collection of one-hot vectors into a `OneHotMatrix` that stores `xs`
# directly (no copy). Every vector must share one length, which becomes the
# matrix height; an empty collection has no height, so `first` throws.
function batch(xs::AbstractArray{<:OneHotVector})
    height = length(first(xs))
    all(x -> length(x) == height, xs) || throw(DimensionMismatch(
        "one-hot vectors passed to batch must all have the same length"))
    return OneHotMatrix(height, xs)
end
2018-01-08 16:31:23 +00:00
import Adapt.adapt
2017-10-04 17:55:56 +00:00
# Adapt a `OneHotMatrix` to target `T` by adapting only its column storage; the
# row count is a plain integer and carries over unchanged.
function adapt(T, xs::OneHotMatrix)
    columns = adapt(T, xs.data)
    return OneHotMatrix(xs.height, columns)
end
2018-08-03 11:54:24 +00:00
# Optional GPU support: `@require` (Requires.jl) evaluates this block only when
# the CuArrays package (identified by its UUID) is loaded.
@init @require CuArrays="3a865a2d-5b23-5a0f-bc46-62713ec82fae" begin
    import .CuArrays: CuArray, cudaconvert
    import Base.Broadcast: BroadcastStyle, ArrayStyle

    # Broadcasting over a CuArray-backed one-hot matrix uses CuArray's style.
    BroadcastStyle(::Type{<:OneHotMatrix{<:CuArray}}) = ArrayStyle{CuArray}()

    # Convert only the column storage with `cudaconvert`; the height is unchanged.
    function cudaconvert(x::OneHotMatrix{<:CuArray})
        converted = cudaconvert(x.data)
        return OneHotMatrix(x.height, converted)
    end
end
2017-10-16 23:07:58 +00:00
"""
    onehot(l, labels)

Encode `l` as a `OneHotVector` over `labels`. The result has length
`length(labels)` and is hot at the first index where `labels` holds a value
`isequal` to `l`. Raises an error if no such (positive) index exists.
"""
function onehot(l, labels)
    pos = findfirst(isequal(l), labels)
    found = pos !== nothing && pos > 0
    found || error("Value $l is not in labels")
    return OneHotVector(pos, length(labels))
end
2017-11-29 16:45:50 +00:00
"""
    onehot(l, labels, unk)

Like `onehot(l, labels)`, but when `l` is not found in `labels`, encode the
fallback label `unk` instead of raising (which still errors if `unk` itself is
not in `labels`).
"""
function onehot(l, labels, unk)
    pos = findfirst(isequal(l), labels)
    if pos !== nothing && pos > 0
        return OneHotVector(pos, length(labels))
    end
    # `l` is not a known label: encode the designated unknown label instead.
    return onehot(unk, labels)
end
"""
    onehotbatch(ls, labels, unk...)

Encode every element of `ls` with `onehot(l, labels, unk...)` and collect the
resulting vectors as the columns of a `OneHotMatrix` of height `length(labels)`.
"""
function onehotbatch(ls, labels, unk...)
    columns = [onehot(l, labels, unk...) for l in ls]
    return OneHotMatrix(length(labels), columns)
end
2017-09-06 22:58:55 +00:00
2018-08-21 03:29:57 +00:00
import Base:argmax
2017-09-06 22:58:55 +00:00
2018-08-21 03:29:57 +00:00
"""
    argmax(y::AbstractVector, labels)
    argmax(y::AbstractMatrix, labels)

Decode scores back to labels: return the element of `labels` at the position of
the largest entry of `y` (ties resolve to the first maximum). For a matrix, each
column is decoded independently and a vector with one label per column is
returned.
"""
# `findmax` locates the maximum in a single pass. The former
# `findfirst(isequal(maximum(y)), y)` needed two passes and carried a dead
# `something(…, 0)` sentinel that could only have produced `labels[0]`.
argmax(y::AbstractVector, labels) = labels[findmax(y)[2]]

argmax(y::AbstractMatrix, labels) =
    dropdims(mapslices(col -> argmax(col, labels), y, dims=1), dims=1)
2017-11-09 14:53:26 +00:00
# Ambiguity hack: `TrackedMatrix * <one-hot>` is routed explicitly, via `invoke`,
# to the plain `AbstractMatrix` one-hot methods, bypassing normal dispatch.
# (Presumably the tracked-array `*` methods would otherwise be ambiguous with
# the one-hot methods — confirm against the tracker definitions.)
function Base.:*(a::TrackedMatrix, b::OneHotVector)
    return invoke(Base.:*, Tuple{AbstractMatrix,OneHotVector}, a, b)
end

function Base.:*(a::TrackedMatrix, b::OneHotMatrix)
    return invoke(Base.:*, Tuple{AbstractMatrix,OneHotMatrix}, a, b)
end