diff --git a/src/onehot.jl b/src/onehot.jl index 42945388..ef326650 100644 --- a/src/onehot.jl +++ b/src/onehot.jl @@ -20,23 +20,12 @@ end Base.size(xs::OneHotMatrix) = (Int64(xs.height),length(xs.data)) -Base.getindex(xs::OneHotMatrix, i::Integer, j::Integer) = xs.data[j][i] +Base.getindex(xs::OneHotMatrix, i::Union{Integer, AbstractVector}, j::Integer) = xs.data[j][i] Base.getindex(xs::OneHotMatrix, ::Colon, i::Integer) = xs.data[i] Base.getindex(xs::OneHotMatrix, ::Colon, i::AbstractArray) = OneHotMatrix(xs.height, xs.data[i]) Base.getindex(xs::OneHotMatrix, i::Integer, ::Colon) = map(x -> x[i], xs.data) -# handle special case when we want the whole column -function Base.getindex(xs::Flux.OneHotMatrix{T}, ot::Union{Base.Slice, Base.OneTo}, i::Int) where {T<:AbstractArray} - res = similar(xs, size(xs, 1), 1) - if length(ot) == size(xs, 1) - res = xs[:,i] - else - res = xs[1:length(ot),i] - end - res -end - A::AbstractMatrix * B::OneHotMatrix = A[:, map(x->x.ix, B.data)] Base.hcat(x::OneHotVector, xs::OneHotVector...) = OneHotMatrix(length(x), [x, xs...]) @@ -76,7 +65,8 @@ onecold(y::AbstractVector, labels = 1:length(y)) = labels[Base.argmax(y)] onecold(y::AbstractMatrix, labels...) = dropdims(mapslices(y -> onecold(y, labels...), y, dims=1), dims=1) -onecold(y::OneHotMatrix, labels...) = map(x -> onecold(x, labels...), y.data) +onecold(y::OneHotMatrix, labels...) = + mapreduce(x -> Flux.onecold(x, labels...), |, y.data, dims = 2, init = 0) function argmax(xs...) Base.depwarn("`argmax(...)` is deprecated, use `onecold(...)` instead.", :argmax)