2017-11-09 14:53:26 +00:00
|
|
|
import Base: *
|
|
|
|
|
2017-09-06 22:58:55 +00:00
|
|
|
"""
    OneHotVector(ix, of)

Boolean vector stored as two integers instead of a dense array.
Indexing returns `true` only at position `ix`, and `size` reports `of` entries.
"""
struct OneHotVector <: AbstractVector{Bool}
  # Position of the single `true` entry.
  ix::UInt32
  # Total number of entries in the vector.
  of::UInt32
end
|
|
|
|
|
|
|
|
# Shape of a one-hot vector: a 1-tuple holding the stored length `of` as an Int64.
function Base.size(xs::OneHotVector)
  return (Int64(xs.of),)
end
|
|
|
|
|
|
|
|
# Entry `i` is `true` exactly when `i` equals the stored hot index `ix`.
# No bounds check is performed here: any other integer yields `false`.
function Base.getindex(xs::OneHotVector, i::Integer)
  return i == xs.ix
end
|
|
|
|
|
2017-11-09 14:53:26 +00:00
|
|
|
# Matrix × one-hot vector: the product is the column of `A` selected by the
# hot index, so it is computed by indexing instead of a real multiplication.
#
# Throws `DimensionMismatch` when the number of columns of `A` differs from the
# length of `b`; without this check a mismatched product would silently return
# a column of `A` whenever `b.ix` happened to be in range.
function *(A::AbstractMatrix, b::OneHotVector)
  size(A, 2) == b.of || throw(DimensionMismatch(
    "matrix A has dimensions $(size(A)), one-hot vector b has length $(Int(b.of))"))
  return A[:, b.ix]
end
|
2017-09-06 22:58:55 +00:00
|
|
|
|
2017-09-27 20:58:34 +00:00
|
|
|
"""
    OneHotMatrix(height, data)

Boolean matrix whose columns are the `OneHotVector`s stored in `data`.
`height` is the number of rows reported by `size`.
"""
struct OneHotMatrix{A<:AbstractVector{OneHotVector}} <: AbstractMatrix{Bool}
  # Number of rows.
  height::Int
  # One `OneHotVector` per column.
  data::A
end
|
|
|
|
|
2017-10-02 19:50:11 +00:00
|
|
|
# Shape of a one-hot matrix: `height` rows by one column per stored vector.
function Base.size(xs::OneHotMatrix)
  nrows = Int64(xs.height)
  ncols = length(xs.data)
  return (nrows, ncols)
end
|
2017-09-06 22:58:55 +00:00
|
|
|
|
2017-12-15 16:17:39 +00:00
|
|
|
# Scalar lookup: take column `j` from the stored vectors, then read its entry `i`.
function Base.getindex(xs::OneHotMatrix, i::Integer, j::Integer)
  column = xs.data[j]
  return column[i]
end
|
|
|
|
# Whole-column lookup: returns the stored column element itself, not a dense copy.
function Base.getindex(xs::OneHotMatrix, ::Colon, i::Integer)
  return xs.data[i]
end
|
|
|
|
# Multi-column lookup: builds a new OneHotMatrix with the same height from the
# selected columns.
function Base.getindex(xs::OneHotMatrix, ::Colon, i::AbstractArray)
  selected = xs.data[i]
  return OneHotMatrix(xs.height, selected)
end
|
2017-09-06 22:58:55 +00:00
|
|
|
|
2017-11-09 14:53:26 +00:00
|
|
|
# Matrix × one-hot matrix: each column of the product is the column of `A`
# selected by the corresponding one-hot column of `B`, so the product is
# computed by gathering columns instead of a real multiplication.
#
# Throws `DimensionMismatch` when the number of columns of `A` differs from the
# height of `B`; without this check a mismatched product would silently gather
# columns whenever the hot indices happened to be in range.
function *(A::AbstractMatrix, B::OneHotMatrix)
  size(A, 2) == B.height || throw(DimensionMismatch(
    "matrix A has dimensions $(size(A)), one-hot matrix B has height $(B.height)"))
  return A[:, map(x -> x.ix, B.data)]
end
|
2017-09-06 22:58:55 +00:00
|
|
|
|
2017-10-15 22:44:16 +00:00
|
|
|
# Horizontal concatenation of one-hot vectors yields a OneHotMatrix whose
# columns are the given vectors, with the height taken from the first vector.
#
# Throws `DimensionMismatch` when the vectors do not all have the same length;
# otherwise the resulting matrix would report a height that is wrong for some
# of its columns.
function Base.hcat(x::OneHotVector, xs::OneHotVector...)
  height = length(x)
  all(v -> length(v) == height, xs) || throw(DimensionMismatch(
    "one-hot vectors must all have the same length to be concatenated"))
  return OneHotMatrix(height, [x, xs...])
end
|
2017-09-06 22:58:55 +00:00
|
|
|
|
2017-10-15 22:44:40 +00:00
|
|
|
# Wrap a collection of one-hot vectors as a OneHotMatrix; the height is taken
# from the first element.
function batch(xs::AbstractArray{<:OneHotVector})
  height = length(first(xs))
  return OneHotMatrix(height, xs)
end
|
|
|
|
|
2017-10-04 17:55:56 +00:00
|
|
|
import NNlib.adapt
|
|
|
|
|
|
|
|
# Adapt a OneHotMatrix by adapting its column storage and rewrapping it with
# the same height.
function adapt(T, xs::OneHotMatrix)
  columns = adapt(T, xs.data)
  return OneHotMatrix(xs.height, columns)
end
|
|
|
|
|
2017-09-27 20:58:34 +00:00
|
|
|
# CuArrays integration, set up through `@require`.
@require CuArrays begin
  import CuArrays: CuArray, cudaconvert

  # Report CuArray as the broadcast container type for a OneHotMatrix whose
  # column storage is a CuArray.
  function Base.Broadcast._containertype(::Type{<:OneHotMatrix{<:CuArray}})
    return CuArray
  end

  # Convert the column storage with `cudaconvert` and rewrap it, keeping the
  # same height.
  function cudaconvert(x::OneHotMatrix{<:CuArray})
    converted = cudaconvert(x.data)
    return OneHotMatrix(x.height, converted)
  end
end
|
|
|
|
|
2017-10-16 23:07:58 +00:00
|
|
|
"""
    onehot(l, labels)

Build a `OneHotVector` of length `length(labels)` whose hot index is the
position of the first occurrence of `l` in `labels`.
Raises an error when `l` is not found in `labels`.
"""
function onehot(l, labels)
  idx = findfirst(labels, l)
  # `findfirst` yields 0 when there is no match.
  if !(idx > 0)
    error("Value $l is not in labels")
  end
  return OneHotVector(idx, length(labels))
end
|
|
|
|
|
2017-11-29 16:45:50 +00:00
|
|
|
"""
    onehot(l, labels, unk)

Like `onehot(l, labels)`, but when `l` is not found in `labels` the result is
`onehot(unk, labels)` instead of an error.
"""
function onehot(l, labels, unk)
  idx = findfirst(labels, l)
  if idx > 0
    return OneHotVector(idx, length(labels))
  end
  # Not found: fall back to the encoding of the `unk` label.
  return onehot(unk, labels)
end
|
|
|
|
|
|
|
|
# Encode every element of `ls` with `onehot` (forwarding the optional `unk`
# argument) and collect the results as the columns of a OneHotMatrix of height
# `length(labels)`.
function onehotbatch(ls, labels, unk...)
  height = length(labels)
  columns = [onehot(l, labels, unk...) for l in ls]
  return OneHotMatrix(height, columns)
end
|
2017-09-06 22:58:55 +00:00
|
|
|
|
2017-09-11 12:40:11 +00:00
|
|
|
# Return the label at the position of the first maximal entry of `y`.
# By default `labels` is `1:length(y)`, so the result is the index itself.
function argmax(y::AbstractVector, labels = 1:length(y))
  peak = maximum(y)
  position = findfirst(y, peak)
  return labels[position]
end
|
|
|
|
|
2017-09-11 12:40:11 +00:00
|
|
|
# Column-wise argmax: apply the vector method to each slice along dimension 1
# (forwarding the optional labels), then drop the leftover singleton dimension.
function argmax(y::AbstractMatrix, l...)
  percolumn = mapslices(col -> argmax(col, l...), y, 1)
  return squeeze(percolumn, 1)
end
|
2017-11-09 14:53:26 +00:00
|
|
|
|
|
|
|
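# Ambiguity hack: explicit `TrackedMatrix` methods for products with one-hot
# arguments, presumably so that dispatch is not ambiguous between the tracked
# matrix rules and the `AbstractMatrix` one-hot methods (TODO confirm).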
|
|
|
|
|
|
|
|
# Tracked matrix × one-hot vector: record the product as a `Tracker.Call` and
# wrap it in a `TrackedArray`.
function *(a::TrackedMatrix, b::OneHotVector)
  call = Tracker.Call(*, a, b)
  return TrackedArray(call)
end
|
|
|
|
# Tracked matrix × one-hot matrix: record the product as a `Tracker.Call` and
# wrap it in a `TrackedArray`.
function *(a::TrackedMatrix, b::OneHotMatrix)
  call = Tracker.Call(*, a, b)
  return TrackedArray(call)
end
|