Merge pull request #122 from staticfloat/sf/weighted_crossentropy

Add `weighted_crossentropy` for imbalanced classification problems
This commit is contained in:
Mike J Innes 2017-12-13 15:29:54 +00:00 committed by GitHub
commit 9926202408
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 30 additions and 2 deletions

View File

@ -4,8 +4,9 @@ using NNlib: log_fast
mse(, y) = sum(( .- y).^2)/length(y)
crossentropy(::AbstractVecOrMat, y::AbstractVecOrMat) =
-sum(y .* log_fast.()) / size(y, 2)
function crossentropy(::AbstractVecOrMat, y::AbstractVecOrMat; weight = 1)
return -sum(y .* log_fast.() .* weight) / size(y, 2)
end
@deprecate logloss(x, y) crossentropy(x, y)

26
test/layers/stateless.jl Normal file
View File

@ -0,0 +1,26 @@
using Flux: onehotbatch, mse, crossentropy
@testset "losses" begin
# First, regression-style y's
y = [1, 1, 0, 0]
y_hat = [.9, .1, .1, .9]
@testset "mse" begin
@test mse(y_hat, y) (.1^2 + .9^2)/2
end
# Now onehot y's
y = onehotbatch([1, 1, 0, 0], 0:1)
y_hat = [.1 .9; .9 .1; .9 .1; .1 .9]'
y_logloss = 1.203972804325936
@testset "crossentropy" begin
@test crossentropy(y_hat, y) y_logloss
end
@testset "weighted_crossentropy" begin
@test crossentropy(y_hat, y, weight = ones(2)) y_logloss
@test crossentropy(y_hat, y, weight = [.5, .5]) y_logloss/2
@test crossentropy(y_hat, y, weight = [2, .5]) 1.5049660054074199
end
end

View File

@ -5,6 +5,7 @@ using Flux, Base.Test
include("utils.jl")
include("tracker.jl")
include("layers/normalisation.jl")
include("layers/stateless.jl")
include("optimise.jl")
end