From 41446d547fa5e1a80d6c928fe4a24ce8ae280dc3 Mon Sep 17 00:00:00 2001 From: Elliot Saba Date: Tue, 5 Dec 2017 15:38:15 -0800 Subject: [PATCH 1/2] Add `weighted_crossentropy` for imbalanced classification problems --- src/layers/stateless.jl | 11 +++++++++-- test/layers/stateless.jl | 26 ++++++++++++++++++++++++++ test/runtests.jl | 1 + 3 files changed, 36 insertions(+), 2 deletions(-) create mode 100644 test/layers/stateless.jl diff --git a/src/layers/stateless.jl b/src/layers/stateless.jl index edbdec58..8d675735 100644 --- a/src/layers/stateless.jl +++ b/src/layers/stateless.jl @@ -4,8 +4,15 @@ using NNlib: log_fast mse(ŷ, y) = sum((ŷ .- y).^2)/length(y) -crossentropy(ŷ::AbstractVecOrMat, y::AbstractVecOrMat) = - -sum(y .* log_fast.(ŷ)) / size(y, 2) +function crossentropy(ŷ::AbstractVecOrMat, y::AbstractVecOrMat) + return -sum(y .* log_fast.(ŷ)) / size(y, 2) +end + +function weighted_crossentropy(ŷ::AbstractVecOrMat, y::AbstractVecOrMat, w::AbstractVecOrMat) + return -sum(y .* log_fast.(ŷ) .* w) / size(y, 2) +end + + @deprecate logloss(x, y) crossentropy(x, y) diff --git a/test/layers/stateless.jl b/test/layers/stateless.jl new file mode 100644 index 00000000..b7a42841 --- /dev/null +++ b/test/layers/stateless.jl @@ -0,0 +1,26 @@ +using Flux: onehotbatch, mse, crossentropy, weighted_crossentropy + +@testset "losses" begin + # First, regression-style y's + y = [1, 1, 0, 0] + y_hat = [.9, .1, .1, .9] + + @testset "mse" begin + @test mse(y_hat, y) ≈ (.1^2 + .9^2)/2 + end + + # Now onehot y's + y = onehotbatch([1, 1, 0, 0], 0:1) + y_hat = [.1 .9; .9 .1; .9 .1; .1 .9]' + y_logloss = 1.203972804325936 + + @testset "crossentropy" begin + @test crossentropy(y_hat, y) ≈ y_logloss + end + + @testset "weighted_crossentropy" begin + @test weighted_crossentropy(y_hat, y, ones(2)) ≈ y_logloss + @test weighted_crossentropy(y_hat, y, [.5, .5]) ≈ y_logloss/2 + @test weighted_crossentropy(y_hat, y, [2, .5]) ≈ 1.5049660054074199 + end +end diff --git a/test/runtests.jl b/test/runtests.jl index efd1a462..5c6ba549 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -5,5 +5,6 @@ using Flux, Base.Test include("utils.jl") include("tracker.jl") include("layers/normalisation.jl") +include("layers/stateless.jl") end From e3a688e70646cc832e6a69acdb1efe6cdbe5eb36 Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Wed, 13 Dec 2017 15:27:15 +0000 Subject: [PATCH 2/2] use kwarg --- src/layers/stateless.jl | 10 ++-------- test/layers/stateless.jl | 8 ++++---- 2 files changed, 6 insertions(+), 12 deletions(-) diff --git a/src/layers/stateless.jl b/src/layers/stateless.jl index 8d675735..63c40cb8 100644 --- a/src/layers/stateless.jl +++ b/src/layers/stateless.jl @@ -4,16 +4,10 @@ using NNlib: log_fast mse(ŷ, y) = sum((ŷ .- y).^2)/length(y) -function crossentropy(ŷ::AbstractVecOrMat, y::AbstractVecOrMat) - return -sum(y .* log_fast.(ŷ)) / size(y, 2) +function crossentropy(ŷ::AbstractVecOrMat, y::AbstractVecOrMat; weight = 1) + return -sum(y .* log_fast.(ŷ) .* weight) / size(y, 2) end -function weighted_crossentropy(ŷ::AbstractVecOrMat, y::AbstractVecOrMat, w::AbstractVecOrMat) - return -sum(y .* log_fast.(ŷ) .* w) / size(y, 2) -end - - - @deprecate logloss(x, y) crossentropy(x, y) function logitcrossentropy(logŷ::AbstractVecOrMat, y::AbstractVecOrMat) diff --git a/test/layers/stateless.jl b/test/layers/stateless.jl index b7a42841..23304eb1 100644 --- a/test/layers/stateless.jl +++ b/test/layers/stateless.jl @@ -1,4 +1,4 @@ -using Flux: onehotbatch, mse, crossentropy, weighted_crossentropy +using Flux: onehotbatch, mse, crossentropy @testset "losses" begin # First, regression-style y's @@ -19,8 +19,8 @@ using Flux: onehotbatch, mse, crossentropy, weighted_crossentropy end @testset "weighted_crossentropy" begin - @test weighted_crossentropy(y_hat, y, ones(2)) ≈ y_logloss - @test weighted_crossentropy(y_hat, y, [.5, .5]) ≈ y_logloss/2 - @test weighted_crossentropy(y_hat, y, [2, .5]) ≈ 1.5049660054074199 + @test crossentropy(y_hat, y, weight = ones(2)) ≈ y_logloss + @test crossentropy(y_hat, y, weight = [.5, .5]) ≈ y_logloss/2 + @test crossentropy(y_hat, y, weight = [2, .5]) ≈ 1.5049660054074199 end end