Merge #680
680: Added new loss functions. r=thebhatman a=thebhatman

I have added the KL divergence loss, Poisson loss, logcosh loss, and hinge loss functions.

Co-authored-by: Manjunath Bhat <manjunathbhat9920@gmail.com>
Co-authored-by: thebhatman <manjunathbhat9920@gmail.com>
commit d1edd9b16d
@@ -65,3 +65,15 @@ AlphaDropout
LayerNorm
GroupNorm
```

## Cost Functions
```@docs
mse
crossentropy
logitcrossentropy
binarycrossentropy
logitbinarycrossentropy
kldivergence
poisson
hinge
```
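
The `@docs` block above only lists the documented names. As a quick orientation, here is a minimal usage sketch, assuming Flux is loaded; the numeric values are illustrative and not taken from the PR.

```julia
using Flux

ŷ = [0.9, 0.1, 0.8]   # model predictions (illustrative)
y = [1.0, 0.0, 1.0]   # targets

Flux.mse(ŷ, y)        # cost functions take the prediction first, then the target
```
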
@@ -84,3 +84,29 @@ function normalise(x::AbstractArray; dims=1)
  σ′ = std(x, dims = dims, mean = μ′, corrected=false)
  return (x .- μ′) ./ σ′
end

"""
    kldivergence(ŷ, y)
KL divergence is a measure of how much one probability distribution differs from another.
It is always non-negative, and zero only when both distributions are equal everywhere.
[KL Divergence](https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence).
"""
function kldivergence(ŷ, y)
  entropy = sum(y .* log.(y)) * 1 // size(y, 2)
  cross_entropy = crossentropy(ŷ, y)
  return entropy + cross_entropy
end
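
As a hedged check of what the new `kldivergence` computes, the sketch below compares it against the textbook definition; it assumes `crossentropy` divides by `size(y, 2)` as in the PR, and the distributions are illustrative.

```julia
using Flux

# One discrete distribution per column; a plain vector counts as a single column.
y = [0.1, 0.2, 0.7]   # target distribution (positive entries, sums to 1)
ŷ = [0.3, 0.3, 0.4]   # predicted distribution

# The definition above is sum(y .* log.(y)) / N + crossentropy(ŷ, y), which for
# positive entries reduces to the textbook Σ y * log(y / ŷ), averaged over N columns.
Flux.kldivergence(ŷ, y) ≈ sum(y .* log.(y ./ ŷ))   # expected to hold (N = 1 here)
Flux.kldivergence(y, y) ≈ 0                        # zero when the distributions coincide
```
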
"""
|
||||||
|
poisson(ŷ, y)
|
||||||
|
Poisson loss function is a measure of how the predicted distribution diverges from the expected distribution.
|
||||||
|
[Poisson Loss](https://peltarion.com/knowledge-center/documentation/modeling-view/build-an-ai-model/loss-functions/poisson).
|
||||||
|
"""
|
||||||
|
poisson(ŷ, y) = sum(ŷ .- y .* log.(ŷ)) *1 // size(y,2)
|
||||||
|
|
||||||
|
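
For completeness, a small hedged sanity check of the `poisson` definition above, using made-up values in the same style as the tests.

```julia
using Flux

y = [0.1 0.2 0.3]   # observed values (illustrative)
ŷ = [0.4 0.5 0.6]   # predictions

# poisson(ŷ, y) is the sum of ŷ - y*log(ŷ), averaged over the columns of y.
Flux.poisson(ŷ, y) ≈ sum(ŷ .- y .* log.(ŷ)) / size(y, 2)   # expected to hold
```
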
"""
|
||||||
|
hinge(ŷ, y)
|
||||||
|
Measures the loss given the prediction ŷ and true labels y(containing 1 or -1).
|
||||||
|
[Hinge Loss](https://en.wikipedia.org/wiki/Hinge_loss).
|
||||||
|
"""
|
||||||
|
hinge(ŷ, y) = sum(max.(0, 1 .- ŷ .* y)) *1 // size(y,2)
|
||||||
|
|
|
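
To make the hinge behaviour concrete, here is a hedged sketch with illustrative scores and ±1 labels: a prediction with the correct sign and magnitude of at least 1 contributes nothing, everything else is penalised linearly.

```julia
using Flux

y = [1 -1 1 -1]           # true labels in {-1, +1}
ŷ = [2.0 -0.5 0.3 -3.0]   # raw classifier scores

# hinge(ŷ, y) is the mean of max(0, 1 - ŷ*y) over the columns of y.
Flux.hinge(ŷ, y) ≈ sum(max.(0, 1 .- ŷ .* y)) / size(y, 2)   # expected to hold (≈ 0.3 here)
Flux.hinge([2.0 3.0], [1 1]) ≈ 0   # confident, correct predictions incur no loss
```
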
@@ -49,12 +49,33 @@ const ϵ = 1e-7
  @testset "logitbinarycrossentropy" begin
    @test logitbinarycrossentropy.(logŷ, y) ≈ binarycrossentropy.(σ.(logŷ), y; ϵ=0)
  end

  y = [1 2 3]
  y1 = [4.0 5.0 6.0]
  @testset "kldivergence" begin
    @test Flux.kldivergence(y, y1) ≈ 4.761838062403337
    @test Flux.kldivergence(y, y) ≈ 0
  end

  y = [1 2 3 4]
  y1 = [5.0 6.0 7.0 8.0]
  @testset "hinge" begin
    @test Flux.hinge(y, y1) ≈ 0
    @test Flux.hinge(y, 0.5 .* y) ≈ 0.125
  end

  y = [0.1 0.2 0.3]
  y1 = [0.4 0.5 0.6]
  @testset "poisson" begin
    @test Flux.poisson(y, y1) ≈ 1.0160455586700767
    @test Flux.poisson(y, y) ≈ 0.5044459776946685
  end

  @testset "no spurious promotions" begin
    for T in (Float32, Float64)
      y = rand(T, 2)
      ŷ = rand(T, 2)
-     for f in (mse, crossentropy, logitcrossentropy)
+     for f in (mse, crossentropy, logitcrossentropy, Flux.kldivergence, Flux.hinge, Flux.poisson)
        fwd, back = Flux.pullback(f, ŷ, y)
        @test fwd isa T
        @test eltype(back(one(T))[1]) == T