mapreduce for onehotmatrix
This commit is contained in:
parent
2ec35861b5
commit
6825639f79
|
@ -20,23 +20,12 @@ end
|
||||||
|
|
||||||
Base.size(xs::OneHotMatrix) = (Int64(xs.height),length(xs.data))
|
Base.size(xs::OneHotMatrix) = (Int64(xs.height),length(xs.data))
|
||||||
|
|
||||||
Base.getindex(xs::OneHotMatrix, i::Integer, j::Integer) = xs.data[j][i]
|
Base.getindex(xs::OneHotMatrix, i::Union{Integer, AbstractVector}, j::Integer) = xs.data[j][i]
|
||||||
Base.getindex(xs::OneHotMatrix, ::Colon, i::Integer) = xs.data[i]
|
Base.getindex(xs::OneHotMatrix, ::Colon, i::Integer) = xs.data[i]
|
||||||
Base.getindex(xs::OneHotMatrix, ::Colon, i::AbstractArray) = OneHotMatrix(xs.height, xs.data[i])
|
Base.getindex(xs::OneHotMatrix, ::Colon, i::AbstractArray) = OneHotMatrix(xs.height, xs.data[i])
|
||||||
|
|
||||||
Base.getindex(xs::OneHotMatrix, i::Integer, ::Colon) = map(x -> x[i], xs.data)
|
Base.getindex(xs::OneHotMatrix, i::Integer, ::Colon) = map(x -> x[i], xs.data)
|
||||||
|
|
||||||
# handle special case when we want the whole column
|
|
||||||
function Base.getindex(xs::Flux.OneHotMatrix{T}, ot::Union{Base.Slice, Base.OneTo}, i::Int) where {T<:AbstractArray}
|
|
||||||
res = similar(xs, size(xs, 1), 1)
|
|
||||||
if length(ot) == size(xs, 1)
|
|
||||||
res = xs[:,i]
|
|
||||||
else
|
|
||||||
res = xs[1:length(ot),i]
|
|
||||||
end
|
|
||||||
res
|
|
||||||
end
|
|
||||||
|
|
||||||
A::AbstractMatrix * B::OneHotMatrix = A[:, map(x->x.ix, B.data)]
|
A::AbstractMatrix * B::OneHotMatrix = A[:, map(x->x.ix, B.data)]
|
||||||
|
|
||||||
Base.hcat(x::OneHotVector, xs::OneHotVector...) = OneHotMatrix(length(x), [x, xs...])
|
Base.hcat(x::OneHotVector, xs::OneHotVector...) = OneHotMatrix(length(x), [x, xs...])
|
||||||
|
@ -76,7 +65,8 @@ onecold(y::AbstractVector, labels = 1:length(y)) = labels[Base.argmax(y)]
|
||||||
onecold(y::AbstractMatrix, labels...) =
|
onecold(y::AbstractMatrix, labels...) =
|
||||||
dropdims(mapslices(y -> onecold(y, labels...), y, dims=1), dims=1)
|
dropdims(mapslices(y -> onecold(y, labels...), y, dims=1), dims=1)
|
||||||
|
|
||||||
onecold(y::OneHotMatrix, labels...) = map(x -> onecold(x, labels...), y.data)
|
onecold(y::OneHotMatrix, labels...) =
|
||||||
|
mapreduce(x -> Flux.onecold(x, labels...), |, y.data, dims = 2, init = 0)
|
||||||
|
|
||||||
function argmax(xs...)
|
function argmax(xs...)
|
||||||
Base.depwarn("`argmax(...)` is deprecated, use `onecold(...)` instead.", :argmax)
|
Base.depwarn("`argmax(...)` is deprecated, use `onecold(...)` instead.", :argmax)
|
||||||
|
|
Loading…
Reference in New Issue