From bd6158d7f9ec62619db602b2e14c5e1e0546848e Mon Sep 17 00:00:00 2001 From: Dhairya Gandhi Date: Sun, 3 Feb 2019 03:57:41 +0530 Subject: [PATCH] onehotvector/matrix behaviour --- src/onehot.jl | 21 ++++++++++++++++++++- 1 file changed, 20 insertions(+), 1 deletion(-) diff --git a/src/onehot.jl b/src/onehot.jl index cd29f14e..21524135 100644 --- a/src/onehot.jl +++ b/src/onehot.jl @@ -9,6 +9,8 @@ Base.size(xs::OneHotVector) = (Int64(xs.of),) Base.getindex(xs::OneHotVector, i::Integer) = i == xs.ix +Base.getindex(xs::OneHotVector, ::Colon) = xs + A::AbstractMatrix * b::OneHotVector = A[:, b.ix] struct OneHotMatrix{A<:AbstractVector{OneHotVector}} <: AbstractMatrix{Bool} @@ -22,6 +24,21 @@ Base.getindex(xs::OneHotMatrix, i::Integer, j::Integer) = xs.data[j][i] Base.getindex(xs::OneHotMatrix, ::Colon, i::Integer) = xs.data[i] Base.getindex(xs::OneHotMatrix, ::Colon, i::AbstractArray) = OneHotMatrix(xs.height, xs.data[i]) +Base.getindex(xs::Flux.OneHotMatrix, j::Base.UnitRange, i::Int) = xs.data[i][j] + +Base.getindex(xs::OneHotMatrix, ::Colon, ::Colon) = xs + +# handle special case for when we want the entire column without allocating +function Base.getindex(xs::Flux.OneHotMatrix{T}, ot::Union{Base.Slice, Base.OneTo}, i::Int) where {T<:AbstractArray} + res = similar(xs, size(xs, 1), 1) + if length(ot) == size(xs, 1) + res = xs[:,i] + else + res = xs[1:length(ot),i] + end + res +end + A::AbstractMatrix * B::OneHotMatrix = A[:, map(x->x.ix, B.data)] Base.hcat(x::OneHotVector, xs::OneHotVector...) = OneHotMatrix(length(x), [x, xs...]) @@ -54,13 +71,15 @@ end onehotbatch(ls, labels, unk...) = OneHotMatrix(length(labels), [onehot(l, labels, unk...) for l in ls]) +Base.argmax(xs::OneHotVector) = xs.ix + onecold(y::AbstractVector, labels = 1:length(y)) = labels[Base.argmax(y)] onecold(y::AbstractMatrix, labels...) = dropdims(mapslices(y -> onecold(y, labels...), y, dims=1), dims=1) function argmax(xs...) - Base.depwarn("`argmax(...) is deprecated, use `onecold(...)` instead.", :argmax) + Base.depwarn("`argmax(...)` is deprecated, use `onecold(...)` instead.", :argmax) return onecold(xs...) end