Add weighted_crossentropy
for imbalanced classification problems
This commit is contained in:
parent
cab235a578
commit
41446d547f
@ -4,8 +4,15 @@ using NNlib: log_fast
|
|||||||
|
|
||||||
mse(ŷ, y) = sum((ŷ .- y).^2)/length(y)
|
mse(ŷ, y) = sum((ŷ .- y).^2)/length(y)
|
||||||
|
|
||||||
crossentropy(ŷ::AbstractVecOrMat, y::AbstractVecOrMat) =
|
function crossentropy(ŷ::AbstractVecOrMat, y::AbstractVecOrMat)
|
||||||
-sum(y .* log_fast.(ŷ)) / size(y, 2)
|
return -sum(y .* log_fast.(ŷ)) / size(y, 2)
|
||||||
|
end
|
||||||
|
|
||||||
|
function weighted_crossentropy(ŷ::AbstractVecOrMat, y::AbstractVecOrMat, w::AbstractVecOrMat)
|
||||||
|
return -sum(y .* log_fast.(ŷ) .* w) / size(y, 2)
|
||||||
|
end
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@deprecate logloss(x, y) crossentropy(x, y)
|
@deprecate logloss(x, y) crossentropy(x, y)
|
||||||
|
|
||||||
|
26
test/layers/stateless.jl
Normal file
26
test/layers/stateless.jl
Normal file
@ -0,0 +1,26 @@
|
|||||||
|
using Flux: onehotbatch, mse, crossentropy, weighted_crossentropy
|
||||||
|
|
||||||
|
@testset "losses" begin
|
||||||
|
# First, regression-style y's
|
||||||
|
y = [1, 1, 0, 0]
|
||||||
|
y_hat = [.9, .1, .1, .9]
|
||||||
|
|
||||||
|
@testset "mse" begin
|
||||||
|
@test mse(y_hat, y) ≈ (.1^2 + .9^2)/2
|
||||||
|
end
|
||||||
|
|
||||||
|
# Now onehot y's
|
||||||
|
y = onehotbatch([1, 1, 0, 0], 0:1)
|
||||||
|
y_hat = [.1 .9; .9 .1; .9 .1; .1 .9]'
|
||||||
|
y_logloss = 1.203972804325936
|
||||||
|
|
||||||
|
@testset "crossentropy" begin
|
||||||
|
@test crossentropy(y_hat, y) ≈ y_logloss
|
||||||
|
end
|
||||||
|
|
||||||
|
@testset "weighted_crossentropy" begin
|
||||||
|
@test weighted_crossentropy(y_hat, y, ones(2)) ≈ y_logloss
|
||||||
|
@test weighted_crossentropy(y_hat, y, [.5, .5]) ≈ y_logloss/2
|
||||||
|
@test weighted_crossentropy(y_hat, y, [2, .5]) ≈ 1.5049660054074199
|
||||||
|
end
|
||||||
|
end
|
@ -5,5 +5,6 @@ using Flux, Base.Test
|
|||||||
include("utils.jl")
|
include("utils.jl")
|
||||||
include("tracker.jl")
|
include("tracker.jl")
|
||||||
include("layers/normalisation.jl")
|
include("layers/normalisation.jl")
|
||||||
|
include("layers/stateless.jl")
|
||||||
|
|
||||||
end
|
end
|
||||||
|
Loading…
Reference in New Issue
Block a user