Merge #1053
1053: Added some loss functions with some doc improvements r=CarloLucibello a=AdarshKumar712

Added the following loss functions, with tests:

1. mae
2. mean squared logarithmic error
3. huber loss
4. squared hinge loss
5. dice coeff loss
6. tversky loss

Also added some documentation improvements for a few other functions.

Co-authored-by: Adarsh Kumar <45385384+AdarshKumar712@users.noreply.github.com>
commit af23a5756c
@@ -65,7 +65,10 @@ trainmode!
## Cost Functions

```@docs
+Flux.mae
Flux.mse
+Flux.msle
+Flux.huber_loss
Flux.crossentropy
Flux.logitcrossentropy
Flux.binarycrossentropy
@@ -73,4 +76,7 @@ Flux.logitbinarycrossentropy
Flux.kldivergence
Flux.poisson
Flux.hinge
+Flux.squared_hinge
+Flux.dice_coeff_loss
+Flux.tversky_loss
```

@@ -1,4 +1,12 @@
# Cost functions
+"""
+    mae(ŷ, y)
+
+Return the mean absolute error `sum(abs.(ŷ .- y)) / length(y)`.
+"""
+mae(ŷ, y) = sum(abs.(ŷ .- y)) * 1 // length(y)
+
+
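As a quick sanity check of `mae` as defined above (values worked out by hand from the formula; illustrative sketch, not part of the test suite):

```julia
using Flux

ŷ = [2.0, 3.0]
y = [1.0, 5.0]   # absolute errors: 1.0 and 2.0

Flux.mae(ŷ, y)   # (1.0 + 2.0) / 2 = 1.5
```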
"""
    mse(ŷ, y)

@@ -7,6 +15,36 @@ Return the mean squared error `sum((ŷ .- y).^2) / length(y)`.
mse(ŷ, y) = sum((ŷ .- y).^2) * 1 // length(y)

+
"""
|
||||
msle(ŷ, y; ϵ=eps(eltype(ŷ)))
|
||||
|
||||
Returns the mean of the squared logarithmic errors `sum((log.(ŷ .+ ϵ) .- log.(y .+ ϵ)).^2) / length(y)`.
|
||||
The `ϵ` term provides numerical stability.
|
||||
|
||||
This error penalizes an under-predicted estimate greater than an over-predicted estimate.
|
||||
"""
|
||||
msle(ŷ, y; ϵ=eps(eltype(ŷ))) = sum((log.(ŷ .+ ϵ) .- log.(y .+ ϵ)).^2) * 1 // length(y)
|
||||
|
||||
|
||||
|
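A small check of the under- vs over-prediction asymmetry the `msle` docstring describes (hand-computed; `ϵ` is negligible here):

```julia
using Flux

y = [10.0]

# Same absolute error of 5, but under-prediction costs more:
Flux.msle([5.0], y)    # (log(5) - log(10))^2  ≈ 0.4805
Flux.msle([15.0], y)   # (log(15) - log(10))^2 ≈ 0.1644
```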
+"""
+    huber_loss(ŷ, y; δ=1.0)
+
+Compute the mean of the Huber loss given the prediction `ŷ` and true values `y`. By default, δ is set to 1.0.
+
+                 | 0.5 * (ŷ - y)^2,           for |ŷ - y| <= δ
+    Huber loss = |
+                 | δ * (|ŷ - y| - 0.5 * δ),   otherwise
+
+[Huber Loss](https://en.wikipedia.org/wiki/Huber_loss).
+"""
+function huber_loss(ŷ, y; δ=eltype(ŷ)(1))
+  abs_error = abs.(ŷ .- y)
+  temp = abs_error .< δ    # 1 where the error is within δ (quadratic region), 0 elsewhere
+  x = eltype(ŷ)(0.5)
+  hub_loss = sum(((abs_error.^2) .* temp) .* x .+ δ*(abs_error .- x*δ) .* (1 .- temp)) * 1 // length(y)
+end
+
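A worked example of `huber_loss` exercising both branches of the piecewise definition above (hand-computed, illustrative only):

```julia
using Flux

y = [0.0, 0.0]
ŷ = [0.5, 2.0]          # |errors|: 0.5 (inside δ = 1) and 2.0 (outside)

# quadratic branch: 0.5 * 0.5^2 = 0.125;  linear branch: 1 * (2.0 - 0.5) = 1.5
Flux.huber_loss(ŷ, y)   # (0.125 + 1.5) / 2 = 0.8125
```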
function _crossentropy(ŷ::AbstractVecOrMat, y::AbstractVecOrMat, weight::Nothing)
  return -sum(y .* log.(ŷ)) * 1 // size(y, 2)
end

@@ -102,10 +140,11 @@ end

KL divergence is a measure of how much one probability distribution differs from another.
It is always non-negative, and zero only when both distributions are equal everywhere.
+
[KL Divergence](https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence).
"""
function kldivergence(ŷ, y)
-  entropy = sum(y .* log.(y)) *1 //size(y,2)
+  entropy = sum(y .* log.(y)) * 1 //size(y,2)
  cross_entropy = crossentropy(ŷ, y)
  return entropy + cross_entropy
end

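For intuition, a hand-computed example of `kldivergence`; note the argument order: `kldivergence(ŷ, y)` evaluates KL(y ‖ ŷ), with each column treated as one distribution (sketch only, not from this diff):

```julia
using Flux

y = [0.3, 0.7]            # true distribution
ŷ = [0.5, 0.5]            # predicted distribution

Flux.kldivergence(ŷ, y)   # 0.3*log(0.3/0.5) + 0.7*log(0.7/0.5) ≈ 0.0823
Flux.kldivergence(y, y)   # ≈ 0 for identical distributions
```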
@@ -114,14 +153,50 @@ end
    poisson(ŷ, y)

The Poisson loss function is a measure of how the predicted distribution diverges from the expected distribution.
+Returns `sum(ŷ .- y .* log.(ŷ)) / size(y, 2)`
+
[Poisson Loss](https://peltarion.com/knowledge-center/documentation/modeling-view/build-an-ai-model/loss-functions/poisson).
"""
-poisson(ŷ, y) = sum(ŷ .- y .* log.(ŷ)) *1 // size(y,2)
+poisson(ŷ, y) = sum(ŷ .- y .* log.(ŷ)) * 1 // size(y,2)

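A small check of `poisson` (hand-computed; note it omits the `log(y!)` term of the full Poisson negative log-likelihood, which does not depend on `ŷ`):

```julia
using Flux

y = [1.0 2.0]        # observed counts, one column per observation
ŷ = [1.0 2.0]        # predicted rates

Flux.poisson(ŷ, y)   # ((1 - 1*log(1)) + (2 - 2*log(2))) / 2 ≈ 0.8069
```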
"""
    hinge(ŷ, y)

Measures the loss given the prediction `ŷ` and true labels `y` (containing 1 or -1).
-[Hinge Loss](https://en.wikipedia.org/wiki/Hinge_loss).
+Returns `sum((max.(0, 1 .- ŷ .* y))) / size(y, 2)`
+
+[Hinge Loss](https://en.wikipedia.org/wiki/Hinge_loss)
+See also [`squared_hinge`](@ref).
"""
-hinge(ŷ, y) = sum(max.(0, 1 .- ŷ .* y)) *1 // size(y,2)
+hinge(ŷ, y) = sum(max.(0, 1 .- ŷ .* y)) * 1 // size(y, 2)

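A hand-worked example of `hinge` with ±1 labels (illustrative sketch, not from the test suite):

```julia
using Flux

y = [1 -1 1]          # true labels in {-1, 1}
ŷ = [0.8 -1.2 -0.3]   # raw classifier scores

# margins 1 .- ŷ .* y: [0.2, -0.2, 1.3] → clamped at 0: [0.2, 0.0, 1.3]
Flux.hinge(ŷ, y)      # (0.2 + 0.0 + 1.3) / 3 = 0.5
```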
"""
|
||||
squared_hinge(ŷ, y)
|
||||
|
||||
Computes squared hinge loss given the prediction `ŷ` and true labels `y` (conatining 1 or -1).
|
||||
Returns `sum((max.(0, 1 .- ŷ .* y)).^2) / size(y, 2)`
|
||||
|
||||
See also [`hinge`](@ref).
|
||||
"""
|
||||
squared_hinge(ŷ, y) = sum((max.(0, 1 .- ŷ .* y)).^2) * 1 // size(y, 2)
|
||||
|
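The same inputs under `squared_hinge`, which squares each clamped margin and so penalizes large violations more heavily:

```julia
using Flux

y = [1 -1 1]
ŷ = [0.8 -1.2 -0.3]

Flux.squared_hinge(ŷ, y)   # (0.2^2 + 0.0^2 + 1.3^2) / 3 ≈ 0.5767
```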
+"""
+    dice_coeff_loss(ŷ, y; smooth=1)
+
+Loss function used in image segmentation, based on the Dice coefficient. Similar to the F1 score.
+Returns `1 - (2*sum(ŷ .* y) + smooth) / (sum(ŷ.^2) + sum(y.^2) + smooth)`
+
+[V-Net: Fully Convolutional Neural Networks for Volumetric Medical Image Segmentation](https://arxiv.org/pdf/1606.04797v1.pdf)
+"""
+dice_coeff_loss(ŷ, y; smooth=eltype(ŷ)(1.0)) = 1 - (2*sum(y .* ŷ) + smooth) / (sum(y.^2) + sum(ŷ.^2) + smooth)
+
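A hand-computed example of `dice_coeff_loss` on a small binary mask (illustrative only):

```julia
using Flux

y = [1.0 1.0 0.0 0.0]        # ground-truth mask
ŷ = [1.0 0.0 0.0 0.0]        # one true positive, one false negative

# 1 - (2*1 + 1) / (1 + 2 + 1) = 0.25, with smooth = 1
Flux.dice_coeff_loss(ŷ, y)
```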
+"""
+    tversky_loss(ŷ, y; β=0.7)
+
+Used with imbalanced data to give more weight to false negatives.
+A larger β weighs recall higher than precision (by placing more emphasis on false negatives).
+Returns `1 - (sum(y .* ŷ) + 1) / (sum(y .* ŷ + β*(1 .- y) .* ŷ + (1 - β)*y .* (1 .- ŷ)) + 1)`
+
+[Tversky loss function for image segmentation using 3D fully convolutional deep networks](https://arxiv.org/pdf/1706.05721.pdf)
+"""
+tversky_loss(ŷ, y; β=eltype(ŷ)(0.7)) = 1 - (sum(y .* ŷ) + 1) / (sum(y .* ŷ + β*(1 .- y) .* ŷ + (1 - β)*y .* (1 .- ŷ)) + 1)

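And a hand-worked example of `tversky_loss` on a binary mask, tracing how β splits the denominator between false positives and false negatives in the implementation above (sketch only):

```julia
using Flux

y = [1.0 1.0 1.0 0.0]     # ground-truth mask
ŷ = [1.0 0.0 0.0 0.0]     # two false negatives, no false positives

# numerator:   sum(y .* ŷ) + 1                 = 2
# denominator: sum(TP + β*FP + (1-β)*FN) + 1   = 1 + 0 + 0.3*2 + 1 = 2.6
Flux.tversky_loss(ŷ, y)   # 1 - 2/2.6 ≈ 0.2308
Flux.tversky_loss(y, y)   # 0 for a perfect prediction
```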
@@ -13,6 +13,20 @@ const ϵ = 1e-7
  @test mse(ŷ, y) ≈ (.1^2 + .9^2)/2
end

+@testset "mae" begin
+  @test Flux.mae(ŷ, y) ≈ 1/2
+end
+
+@testset "huber_loss" begin
+  @test Flux.huber_loss(ŷ, y) ≈ 0.20500000000000002
+end
+
+y = [123.0,456.0,789.0]
+ŷ = [345.0,332.0,789.0]
+@testset "msle" begin
+  @test Flux.msle(ŷ, y) ≈ 0.38813985859136585
+end
+
# Now onehot y's
y = onehotbatch([1, 1, 0, 0], 0:1)
ŷ = [.1 .9; .9 .1; .9 .1; .1 .9]'

@@ -51,31 +65,50 @@ const ϵ = 1e-7
end

y = [1 2 3]
+y1 = [4.0 5.0 6.0]
ŷ = [4.0 5.0 6.0]
@testset "kldivergence" begin
+  @test Flux.kldivergence(y, y1) ≈ 4.761838062403337
  @test Flux.kldivergence(ŷ, y) ≈ -1.7661057888493457
  @test Flux.kldivergence(y, y) ≈ 0
end

y = [1 2 3 4]
+y1 = [5.0 6.0 7.0 8.0]
ŷ = [5.0 6.0 7.0 8.0]
@testset "hinge" begin
+  @test Flux.hinge(y, y1) ≈ 0
  @test Flux.hinge(ŷ, y) ≈ 0
  @test Flux.hinge(y, 0.5 .* y) ≈ 0.125
end

+@testset "squared_hinge" begin
+  @test Flux.squared_hinge(ŷ, y) ≈ 0
+  @test Flux.squared_hinge(y, 0.5 .* y) ≈ 0.0625
+end
+
y = [0.1 0.2 0.3]
+y1 = [0.4 0.5 0.6]
ŷ = [0.4 0.5 0.6]
@testset "poisson" begin
+  @test Flux.poisson(y, y1) ≈ 1.0160455586700767
  @test Flux.poisson(ŷ, y) ≈ 0.6278353988097339
  @test Flux.poisson(y, y) ≈ 0.5044459776946685
end

+y = [1.0 0.5 0.3 2.4]
+ŷ = [0 1.4 0.5 1.2]
+@testset "dice_coeff_loss" begin
+  @test Flux.dice_coeff_loss(ŷ, y) ≈ 0.2799999999999999
+  @test Flux.dice_coeff_loss(y, y) ≈ 0.0
+end
+
+@testset "tversky_loss" begin
+  @test Flux.tversky_loss(ŷ, y) ≈ -0.06772009029345383
+  @test Flux.tversky_loss(ŷ, y, β = 0.8) ≈ -0.09490740740740744
+  @test Flux.tversky_loss(y, y) ≈ -0.5576923076923075
+end

@testset "no spurious promotions" begin
  for T in (Float32, Float64)
    y = rand(T, 2)
    ŷ = rand(T, 2)
-    for f in (mse, crossentropy, logitcrossentropy, Flux.kldivergence, Flux.hinge, Flux.poisson)
+    for f in (mse, crossentropy, logitcrossentropy, Flux.kldivergence, Flux.hinge, Flux.poisson,
+              Flux.mae, Flux.huber_loss, Flux.msle, Flux.squared_hinge, Flux.dice_coeff_loss, Flux.tversky_loss)
      fwd, back = Flux.pullback(f, ŷ, y)
      @test fwd isa T
      @test eltype(back(one(T))[1]) == T