mapreduce for onehotmatrix

This commit is contained in:
Dhairya Gandhi 2019-02-28 09:17:18 +05:30
parent 2ec35861b5
commit 6825639f79
1 changed files with 3 additions and 13 deletions

View File

@ -20,23 +20,12 @@ end
Base.size(xs::OneHotMatrix) = (Int64(xs.height),length(xs.data))
Base.getindex(xs::OneHotMatrix, i::Integer, j::Integer) = xs.data[j][i]
Base.getindex(xs::OneHotMatrix, i::Union{Integer, AbstractVector}, j::Integer) = xs.data[j][i]
Base.getindex(xs::OneHotMatrix, ::Colon, i::Integer) = xs.data[i]
Base.getindex(xs::OneHotMatrix, ::Colon, i::AbstractArray) = OneHotMatrix(xs.height, xs.data[i])
Base.getindex(xs::OneHotMatrix, i::Integer, ::Colon) = map(x -> x[i], xs.data)
# handle special case when we want the whole column
function Base.getindex(xs::Flux.OneHotMatrix{T}, ot::Union{Base.Slice, Base.OneTo}, i::Int) where {T<:AbstractArray}
res = similar(xs, size(xs, 1), 1)
if length(ot) == size(xs, 1)
res = xs[:,i]
else
res = xs[1:length(ot),i]
end
res
end
A::AbstractMatrix * B::OneHotMatrix = A[:, map(x->x.ix, B.data)]
Base.hcat(x::OneHotVector, xs::OneHotVector...) = OneHotMatrix(length(x), [x, xs...])
@ -76,7 +65,8 @@ onecold(y::AbstractVector, labels = 1:length(y)) = labels[Base.argmax(y)]
onecold(y::AbstractMatrix, labels...) =
dropdims(mapslices(y -> onecold(y, labels...), y, dims=1), dims=1)
onecold(y::OneHotMatrix, labels...) = map(x -> onecold(x, labels...), y.data)
onecold(y::OneHotMatrix, labels...) =
mapreduce(x -> Flux.onecold(x, labels...), |, y.data, dims = 2, init = 0)
function argmax(xs...)
Base.depwarn("`argmax(...)` is deprecated, use `onecold(...)` instead.", :argmax)