rename argmax as onecold

This commit is contained in:
boathit 2018-08-23 20:47:43 +08:00
parent 33c901c191
commit 6c97846551
2 changed files with 12 additions and 9 deletions

View File

@ -56,11 +56,14 @@ onehotbatch(ls, labels, unk...) =
import Base:argmax import Base:argmax
argmax(y::AbstractVector, labels) = onecold(y::AbstractVector, labels = 1:length(y)) =
labels[something(findfirst(isequal(maximum(y)), y), 0)] labels[something(findfirst(isequal(maximum(y)), y), 0)]
argmax(y::AbstractMatrix, labels) = onecold(y::AbstractMatrix, labels...) =
dropdims(mapslices(y -> argmax(y, labels), y, dims=1), dims=1) dropdims(mapslices(y -> onecold(y, labels...), y, dims=1), dims=1)
@deprecate argmax(y::AbstractVector, labels::AbstractVector) onecold(y, labels)
@deprecate argmax(y::AbstractMatrix, labels::AbstractVector) onecold(y, labels)
# Ambiguity hack # Ambiguity hack

View File

@ -1,13 +1,13 @@
using Flux:argmax using Flux:onecold
using Test using Test
@testset "argmax" begin @testset "argmax" begin
a = [1, 2, 5, 3.] a = [1, 2, 5, 3.]
A = [1 20 5; 2 7 6; 3 9 10; 2 1 14] A = [1 20 5; 2 7 6; 3 9 10; 2 1 14]
labels = ['A', 'B', 'C', 'D'] labels = ['A', 'B', 'C', 'D']
@test argmax(a) == 3 @test onecold(a) == 3
@test argmax(A) == CartesianIndex(1, 2) @test onecold(A) == [3, 1, 4]
@test argmax(a, labels) == 'C' @test onecold(a, labels) == 'C'
@test argmax(A, labels) == ['C', 'A', 'D'] @test onecold(A, labels) == ['C', 'A', 'D']
end end