fix argmax and add test
This commit is contained in:
parent
930776eb1a
commit
616ed194df
@ -54,11 +54,13 @@ end
|
|||||||
onehotbatch(ls, labels, unk...) =
|
onehotbatch(ls, labels, unk...) =
|
||||||
OneHotMatrix(length(labels), [onehot(l, labels, unk...) for l in ls])
|
OneHotMatrix(length(labels), [onehot(l, labels, unk...) for l in ls])
|
||||||
|
|
||||||
argmax(y::AbstractVector, labels = 1:length(y)) =
|
import Base:argmax
|
||||||
labels[findfirst(y, maximum(y))]
|
|
||||||
|
|
||||||
argmax(y::AbstractMatrix, l...) =
|
argmax(y::AbstractVector, labels) =
|
||||||
squeeze(mapslices(y -> argmax(y, l...), y, 1), 1)
|
labels[something(findfirst(isequal(maximum(y)), y), 0)]
|
||||||
|
|
||||||
|
argmax(y::AbstractMatrix, labels) =
|
||||||
|
dropdims(mapslices(y -> argmax(y, labels), y, dims=1), dims=1)
|
||||||
|
|
||||||
# Ambiguity hack
|
# Ambiguity hack
|
||||||
|
|
||||||
|
13
test/onehot.jl
Normal file
13
test/onehot.jl
Normal file
@ -0,0 +1,13 @@
|
|||||||
|
using Flux:argmax
|
||||||
|
using Test
|
||||||
|
|
||||||
|
@testset "argmax" begin
|
||||||
|
a = [1, 2, 5, 3.]
|
||||||
|
A = [1 20 5; 2 7 6; 3 9 10; 2 1 14]
|
||||||
|
labels = ['A', 'B', 'C', 'D']
|
||||||
|
|
||||||
|
@test argmax(a) == 3
|
||||||
|
@test argmax(A) == CartesianIndex(1, 2)
|
||||||
|
@test argmax(a, labels) == 'C'
|
||||||
|
@test argmax(A, labels) == ['C', 'A', 'D']
|
||||||
|
end
|
@ -24,6 +24,7 @@ insert!(LOAD_PATH, 2, "@v#.#")
|
|||||||
@testset "Flux" begin
|
@testset "Flux" begin
|
||||||
|
|
||||||
include("utils.jl")
|
include("utils.jl")
|
||||||
|
include("onehot.jl")
|
||||||
include("tracker.jl")
|
include("tracker.jl")
|
||||||
include("layers/normalisation.jl")
|
include("layers/normalisation.jl")
|
||||||
include("layers/stateless.jl")
|
include("layers/stateless.jl")
|
||||||
|
Loading…
Reference in New Issue
Block a user