Add weighted_crossentropy for imbalanced classification problems

This commit is contained in:
Elliot Saba 2017-12-05 15:38:15 -08:00
parent cab235a578
commit 41446d547f
3 changed files with 36 additions and 2 deletions

View File

@@ -4,8 +4,15 @@ using NNlib: log_fast
# Mean squared error: mean of elementwise squared differences between
# the prediction `ŷ` and the target `y`.
mse(ŷ, y) = sum((ŷ .- y).^2)/length(y)
# Cross entropy loss for one-hot encoded targets.
# `ŷ` holds predicted probabilities, `y` the one-hot targets; columns are
# batch samples, so the sum is normalised by the batch size `size(y, 2)`.
function crossentropy(ŷ::AbstractVecOrMat, y::AbstractVecOrMat)
  return -sum(y .* log_fast.(ŷ)) / size(y, 2)
end
# Weighted cross entropy for imbalanced classification problems.
# `w` gives a per-class weight (one entry per row/class) that scales each
# class's contribution; with `w = ones(nclasses)` this equals `crossentropy`.
# Normalised by the batch size `size(y, 2)`, matching `crossentropy`.
function weighted_crossentropy(ŷ::AbstractVecOrMat, y::AbstractVecOrMat, w::AbstractVecOrMat)
  return -sum(y .* log_fast.(ŷ) .* w) / size(y, 2)
end
# `logloss` was renamed; keep the old name as a deprecated alias.
@deprecate logloss(x, y) crossentropy(x, y)

26
test/layers/stateless.jl Normal file
View File

@@ -0,0 +1,26 @@
using Flux: onehotbatch, mse, crossentropy, weighted_crossentropy

@testset "losses" begin
  # First, regression-style y's
  y = [1, 1, 0, 0]
  y_hat = [.9, .1, .1, .9]

  @testset "mse" begin
    @test mse(y_hat, y) ≈ (.1^2 + .9^2)/2
  end

  # Now onehot y's
  y = onehotbatch([1, 1, 0, 0], 0:1)
  y_hat = [.1 .9; .9 .1; .9 .1; .1 .9]'
  y_logloss = 1.203972804325936

  @testset "crossentropy" begin
    @test crossentropy(y_hat, y) ≈ y_logloss
  end

  @testset "weighted_crossentropy" begin
    # Uniform weights reduce to plain crossentropy; scaling all weights
    # scales the loss linearly.
    @test weighted_crossentropy(y_hat, y, ones(2)) ≈ y_logloss
    @test weighted_crossentropy(y_hat, y, [.5, .5]) ≈ y_logloss/2
    @test weighted_crossentropy(y_hat, y, [2, .5]) ≈ 1.5049660054074199
  end
end

View File

@@ -5,5 +5,6 @@ using Flux, Base.Test
# Test-suite entry includes; `layers/stateless.jl` added by this commit.
include("utils.jl")
include("tracker.jl")
include("layers/normalisation.jl")
include("layers/stateless.jl")

end