CarloLucibello 2020-04-29 11:52:24 +02:00
parent 5f1604d25d
commit 20ed5c5622
9 changed files with 235 additions and 270 deletions

View File

@ -8,8 +8,9 @@ makedocs(modules=[Flux, NNlib],
"Building Models" =>
["Basics" => "models/basics.md",
"Recurrence" => "models/recurrence.md",
"Regularisation" => "models/regularisation.md",
"Model Reference" => "models/layers.md",
"Loss Functions" => "models/losses.md",
"Regularisation" => "models/regularisation.md",
"Advanced Model Building" => "models/advanced.md",
"NNlib" => "models/nnlib.md"],
"Handling Data" =>

View File

@ -67,22 +67,4 @@ Many normalisation layers behave differently under training and inference (testi
```@docs
Flux.testmode!
trainmode!
```
## Cost Functions
```@docs
Flux.mae
Flux.mse
Flux.msle
Flux.huber_loss
Flux.crossentropy
Flux.logitcrossentropy
Flux.binarycrossentropy
Flux.logitbinarycrossentropy
Flux.kldivergence
Flux.poisson
Flux.hinge
Flux.squared_hinge
Flux.dice_coeff_loss
Flux.tversky_loss
```

25
docs/src/models/losses.md Normal file
View File

@ -0,0 +1,25 @@
## Loss Functions
Flux provides a large number of common loss functions used for training machine learning models.
Most loss functions in Flux have an optional argument `agg`, denoting the type of aggregation performed over the
batch:
```julia
loss(ŷ, y; agg=mean)
```
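For instance, the same loss can be reported as a batch mean or as a batch total by changing `agg`. A minimal sketch with made-up data (the values are illustrative only):
```julia
ŷ, y = rand(2, 5), rand(2, 5)

Flux.mse(ŷ, y)             # default: mean over the batch
Flux.mse(ŷ, y; agg=sum)    # total squared error instead of the mean
```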
```@docs
Flux.mae
Flux.mse
Flux.msle
Flux.huber_loss
Flux.crossentropy
Flux.logitcrossentropy
Flux.binarycrossentropy
Flux.logitbinarycrossentropy
Flux.kldivergence
Flux.poisson_loss
Flux.hinge
Flux.squared_hinge
Flux.dice_coeff_loss
Flux.tversky_loss
```

View File

@ -7,9 +7,10 @@ add the result to the overall loss.
For example, say we have a simple regression.
```julia
using Flux: crossentropy
using Flux
using Flux: logitcrossentropy
m = Dense(10, 5)
loss(x, y) = crossentropy(softmax(m(x)), y)
loss(x, y) = logitcrossentropy(m(x), y)
```
We can regularise this by taking the (L2) norm of the parameters, `m.W` and `m.b`.
@ -18,19 +19,19 @@ We can regularise this by taking the (L2) norm of the parameters, `m.W` and `m.b
using LinearAlgebra
penalty() = norm(m.W) + norm(m.b)
loss(x, y) = crossentropy(softmax(m(x)), y) + penalty()
loss(x, y) = logitcrossentropy(m(x), y) + penalty()
```
When working with layers, Flux provides the `params` function to grab all
parameters at once. We can easily penalise everything with `sum(norm, params)`.
parameters at once. We can easily penalise everything with `sum(norm, Flux.params(m))`:
```julia
julia> params(m)
julia> Flux.params(m)
2-element Array{Any,1}:
param([0.355408 0.533092; … 0.430459 0.171498])
param([0.0, 0.0, 0.0, 0.0, 0.0])
julia> sum(norm, params(m))
julia> sum(norm, Flux.params(m))
26.01749952921026
```
@ -40,9 +41,9 @@ Here's a larger example with a multi-layer perceptron.
m = Chain(
Dense(28^2, 128, relu),
Dense(128, 32, relu),
Dense(32, 10), softmax)
Dense(32, 10))
loss(x, y) = crossentropy(m(x), y) + sum(norm, params(m))
loss(x, y) = logitcrossentropy(m(x), y) + sum(norm, Flux.params(m))
loss(rand(28^2), rand(10))
```
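The regularised loss can then be differentiated and optimised like any other. A minimal sketch, assuming the `m` and `loss` defined above and Flux's usual Zygote-based `gradient`:
```julia
x, y = rand(28^2), rand(10)

gs = gradient(() -> loss(x, y), Flux.params(m))  # gradients include the penalty term
```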

View File

@ -31,6 +31,7 @@ include("onehot.jl")
include("functor.jl")
include("layers/stateless.jl")
include("layers/losses.jl")
include("layers/basic.jl")
include("layers/conv.jl")
include("layers/recurrent.jl")

190
src/layers/losses.jl Normal file
View File

@ -0,0 +1,190 @@
# Cost functions
"""
mae(, y; agg=mean)
Return the loss corresponding to mean absolute error:
agg(abs.( .- y))
"""
mae(, y; agg=mean) = agg(abs.( .- y))
"""
mse(, y; agg=mean)
Return the loss corresponding to mean square error:
agg(( .- y).^2)
"""
mse(, y; agg=mean) = agg(( .- y).^2)
"""
msle(, y; agg=mean, ϵ=eps(eltype()))
The loss corresponding to mean squared logarithmic errors, calculated as
agg((log.( .+ ϵ) .- log.(y .+ ϵ)).^2)
The `ϵ` term provides numerical stability.
Penalizes an under-predicted estimate more than an over-predicted estimate.
"""
msle(, y; agg=mean, ϵ=eps(eltype())) = agg((log.( .+ ϵ) .- log.(y .+ ϵ)).^2)
"""
huber_loss(, y; δ=1, agg=mean)
Return the mean of the [Huber loss](https://en.wikipedia.org/wiki/Huber_loss)
given the prediction `` and true values `y`.
| 0.5 * | - y|, for | - y| <= δ
Huber loss = |
| δ * (| - y| - 0.5 * δ), otherwise
"""
function huber_loss(, y; agg=mean, δ=ofeltype(, 1))
abs_error = abs.( .- y)
temp = abs_error .< δ
x = ofeltype(, 0.5)
agg(((abs_error.^2) .* temp) .* x .+ δ*(abs_error .- x*δ) .* (1 .- temp))
end
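# Usage sketch (illustrative values, not part of the API): small residuals fall in the
# quadratic region, large residuals in the linear one, which damps the effect of outliers.
#
#   huber_loss([0.1, 5.0], [0.0, 0.0])        # mixes quadratic and linear terms
#   huber_loss([0.1, 5.0], [0.0, 0.0]; δ=10)  # widen the quadratic region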
wsum(w::Nothing, x; dims) = sum(x, dims=dims)
wsum(w::Number, x; dims) = w .* sum(x, dims=dims)
wsum(w::AbstractArray, x; dims) = sum(w .* x, dims=dims)

"""
    crossentropy(ŷ, y; dims=1, agg=mean, ϵ=eps(eltype(ŷ)), weight=nothing)

Return the cross entropy between the given probability distributions;
calculated as

    agg(.-sum(weight .* y .* log.(ŷ .+ ϵ); dims=dims))

`weight` can be `nothing`, a number or an array.
`weight=nothing` acts like `weight=1` but is faster.

See also: [`Flux.logitcrossentropy`](@ref), [`Flux.binarycrossentropy`](@ref), [`Flux.logitbinarycrossentropy`](@ref)
"""
function crossentropy(ŷ, y; dims=1, agg=mean, ϵ=eps(eltype(ŷ)), weight=nothing)
    agg(.-wsum(weight, y .* log.(ŷ .+ ϵ); dims=dims))
end
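# Usage sketch (illustrative, not part of the API surface): per-class weights scale the
# contribution of each class to the loss.
#
#   ŷ = softmax(rand(3, 5))                    # predicted class probabilities
#   y = Flux.onehotbatch(rand(1:3, 5), 1:3)    # one-hot targets
#   crossentropy(ŷ, y)                         # unweighted, mean over the batch
#   crossentropy(ŷ, y; weight=[1, 1, 10])      # class 3 contributes ten times as much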
"""
logitcrossentropy(, y; weight=nothing, agg=mean, dims=1)
Return the crossentropy computed after a [`Flux.logsoftmax`](@ref) operation;
calculated as
agg(.-sum(weight .* y .* logsoftmax(; dims=dims); dims=dims))
`logitcrossentropy(ŷ, y)` is mathematically equivalent to
[`Flux.crossentropy(softmax(log.(ŷ)), y)`](@ref) but it is more numerically stable.
See also: [`Flux.crossentropy`](@ref), [`Flux.binarycrossentropy`](@ref), [`Flux.logitbinarycrossentropy`](@ref)
"""
function logitcrossentropy(, y; dims=1, agg=mean, weight=nothing)
agg(.-wsum(weight, y .* logsoftmax(; dims=dims); dims=dims))
end
"""
binarycrossentropy(, y; ϵ=eps())
Return ``-y*\\log( + ϵ) - (1-y)*\\log(1- + ϵ)``. The `ϵ` term provides numerical stability.
Typically, the prediction `` is given by the output of a [`sigmoid`](@ref) activation.
See also: [`Flux.crossentropy`](@ref), [`Flux.logitcrossentropy`](@ref), [`Flux.logitbinarycrossentropy`](@ref)
"""
function binarycrossentropy(, y; agg=mean, ϵ=eps(eltype()))
agg(@.(-y*log(+ϵ) - (1-y)*log(1-+ϵ)))
end
# Re-definition to fix interaction with CuArrays.
# CuArrays.@cufunc binarycrossentropy(ŷ, y; ϵ=eps(ŷ)) = -y*log(ŷ + ϵ) - (1 - y)*log(1 - ŷ + ϵ)
"""
logitbinarycrossentropy(ŷ, y; agg=mean)
`logitbinarycrossentropy(ŷ, y)` is mathematically equivalent to
[`Flux.binarycrossentropy(σ(log(ŷ)), y)`](@ref) but it is more numerically stable.
See also: [`Flux.crossentropy`](@ref), [`Flux.logitcrossentropy`](@ref), [`Flux.binarycrossentropy`](@ref)
"""
function logitbinarycrossentropy(, y; agg=mean)
agg(@.((1-y)* - logsigmoid()))
end
# Re-definition to fix interaction with CuArrays.
# CuArrays.@cufunc logitbinarycrossentropy(ŷ, y) = (1 - y)*ŷ - logσ(ŷ)
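# Sanity-check sketch (illustrative): the logit form agrees with applying σ first, while
# skipping the explicit sigmoid/log round-trip that can underflow.
#
#   x, y = randn(3), rand(0:1, 3)
#   logitbinarycrossentropy(x, y) ≈ binarycrossentropy(σ.(x), y)   # true up to ϵ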
"""
kldivergence(, y; dims=1, agg=mean, ϵ=eps(eltype()))
Return the
[Kullback-Leibler divergence](https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence)
between the given arrays interpreted as probability distributions.
KL divergence is a measure of how much one probability distribution is different
from the other.
It is always non-negative and zero only when both the distributions are equal
everywhere.
"""
function kldivergence(, y; dims=1, agg=mean, ϵ=eps(eltype()))
entropy = agg(sum(y .* log.(y .+ ϵ), dims=dims))
cross_entropy = crossentropy(, y; dims=dims, agg=agg, ϵ=ϵ)
return entropy + cross_entropy
end
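# Sketch (illustrative): the divergence vanishes when the two distributions coincide
# and is strictly positive otherwise.
#
#   p = softmax(rand(3, 4))
#   kldivergence(p, p)                      # ≈ 0
#   kldivergence(softmax(randn(3, 4)), p)   # > 0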
"""
poisson_loss(, y; agg=mean)
# Return how much the predicted distribution `ŷ` diverges from the expected Poisson
# distribution `y`; calculated as `sum(ŷ .- y .* log.(ŷ)) / size(y, 2)`.
REDO
[More information.](https://peltarion.com/knowledge-center/documentation/modeling-view/build-an-ai-model/loss-functions/poisson).
"""
poisson_loss(, y; agg=mean) = agg( .- y .* log.())
@deprecate poisson poisson_loss
"""
hinge(, y; agg=mean)
Return the [hinge loss](https://en.wikipedia.org/wiki/Hinge_loss) given the
prediction `` and true labels `y` (containing 1 or -1); calculated as
`agg(max.(0, 1 .- ŷ .* y))`.
See also: [`squared_hinge`](@ref)
"""
hinge(, y; agg=mean) = agg(max.(0, 1 .- .* y))
"""
squared_hinge(, y; agg=mean)
Return the squared hinge loss given the prediction `` and true labels `y`
(containing 1 or -1); calculated as `agg((max.(0, 1 .- ŷ .* y)).^2))`.
See also: [`hinge`](@ref)
"""
squared_hinge(, y; agg=mean) = agg((max.(0, 1 .- .* y)).^2)
"""
dice_coeff_loss(, y; smooth=1)
Return a loss based on the dice coefficient.
Used in the [V-Net](https://arxiv.org/pdf/1606.04797v1.pdf) image segmentation
architecture.
Similar to the F1_score. Calculated as:
1 - 2*sum(| .* y| + smooth) / (sum(.^2) + sum(y.^2) + smooth)`
"""
dice_coeff_loss(, y; smooth=ofeltype(, 1.0)) = 1 - (2*sum(y .* ) + smooth) / (sum(y.^2) + sum(.^2) + smooth) #TODO
"""
tversky_loss(, y; β=0.7)
Return the [Tversky loss](https://arxiv.org/pdf/1706.05721.pdf).
Used with imbalanced data to give more weight to false negatives.
Larger β weigh recall higher than precision (by placing more emphasis on false negatives)
Calculated as:
1 - sum(|y .* | + 1) / (sum(y .* + β*(1 .- y) .* + (1 - β)*y .* (1 .- )) + 1)
"""
tversky_loss(, y; β=ofeltype(, 0.7)) = 1 - (sum(y .* ) + 1) / (sum(y .* + β*(1 .- y) .* + (1 - β)*y .* (1 .- )) + 1) #TODO

View File

@ -1,167 +1,3 @@
# Cost functions
"""
    mae(ŷ, y; agg=mean)

Return the mean absolute error, `agg(abs.(ŷ .- y))`.
"""
mae(ŷ, y; agg=mean) = agg(abs.(ŷ .- y))
"""
    mse(ŷ, y)

Return the mean squared error between `ŷ` and `y`; calculated as
`sum((ŷ .- y).^2) / length(y)`.
# Examples
```jldoctest
julia> Flux.mse([0, 2], [1, 1])
1//1
```
"""
mse(ŷ, y; agg=mean) = agg((ŷ .- y).^2)
"""
    msle(ŷ, y; ϵ=eps(eltype(ŷ)))
Return the mean of the squared logarithmic errors; calculated as
`sum((log.(ŷ .+ ϵ) .- log.(y .+ ϵ)).^2) / length(y)`.
The `ϵ` term provides numerical stability.
Penalizes an under-predicted estimate greater than an over-predicted estimate.
"""
msle(ŷ, y; agg=mean, ϵ=eps(eltype(ŷ))) = agg((log.(ŷ .+ ϵ) .- log.(y .+ ϵ)).^2)
"""
    huber_loss(ŷ, y; δ=1)

Return the mean of the [Huber loss](https://en.wikipedia.org/wiki/Huber_loss)
given the prediction `ŷ` and true values `y`.

                 | 0.5 * |ŷ - y|^2,          for |ŷ - y| <= δ
    Huber loss = |
                 |  δ * (|ŷ - y| - 0.5 * δ), otherwise
"""
function huber_loss(ŷ, y; agg=mean, δ=ofeltype(ŷ, 1))
    abs_error = abs.(ŷ .- y)
    temp = abs_error .< δ
    x = ofeltype(ŷ, 0.5)
    agg(((abs_error.^2) .* temp) .* x .+ δ*(abs_error .- x*δ) .* (1 .- temp))
end
# function _crossentropy(ŷ::AbstractVecOrMat, y::AbstractVecOrMat, weight::Nothing)
# return -sum(y .* log.(ŷ)) * 1 // size(y, 2)
# end
# function _crossentropy(ŷ::AbstractVecOrMat, y::AbstractVecOrMat, weight::Number)
# return -sum(y .* log.(ŷ)) .* weight * 1 // size(y, 2)
# end
# function _crossentropy(ŷ::AbstractVecOrMat, y::AbstractVecOrMat, weight::AbstractVector)
# return -sum(y .* log.(ŷ) .* weight) * 1 // size(y, 2)
# end
"""
    crossentropy(ŷ, y; weight = nothing)
Return the cross entropy between the given probability distributions;
calculated as `-sum(y .* log.(ŷ) .* weight) / size(y, 2)`.
`weight` can be `Nothing`, a `Number` or an `AbstractVector`.
`weight=nothing` acts like `weight=1` but is faster.
See also: [`Flux.logitcrossentropy`](@ref), [`Flux.binarycrossentropy`](@ref), [`Flux.logitbinarycrossentropy`](@ref)
# Examples
```jldoctest
julia> Flux.crossentropy(softmax([-1.1491, 0.8619, 0.3127]), [1, 1, 0])
3.085467254747739
```
"""
# crossentropy(ŷ::AbstractVecOrMat, y::AbstractVecOrMat; weight=nothing) = _crossentropy(ŷ, y, weight)
function crossentropy(ŷ, y; dims=1, agg=mean, ϵ=eps(eltype(ŷ)))
    agg(.-sum(y .* log.(ŷ .+ ϵ); dims=dims))
end
"""
    logitcrossentropy(ŷ, y; weight = 1)
Return the crossentropy computed after a [`Flux.logsoftmax`](@ref) operation;
calculated as `-sum(y .* logsoftmax(ŷ) .* weight) / size(y, 2)`.
`logitcrossentropy(ŷ, y)` is mathematically equivalent to
[`Flux.crossentropy(softmax(log(ŷ)), y)`](@ref) but it is more numerically stable.
See also: [`Flux.crossentropy`](@ref), [`Flux.binarycrossentropy`](@ref), [`Flux.logitbinarycrossentropy`](@ref)
# Examples
```jldoctest
julia> Flux.logitcrossentropy([-1.1491, 0.8619, 0.3127], [1, 1, 0])
3.085467254747738
```
"""
# function logitcrossentropy(ŷ::AbstractVecOrMat, y::AbstractVecOrMat; weight = 1)
# return -sum(y .* logsoftmax(ŷ) .* weight) * 1 // size(y, 2)
# end
function logitcrossentropy(ŷ, y; dims=1, agg=mean)
    agg(.-sum(y .* logsoftmax(ŷ; dims=dims); dims=dims))
end
"""
    binarycrossentropy(ŷ, y; ϵ=eps(eltype(ŷ)))

Return ``-y*\\log(ŷ + ϵ) - (1-y)*\\log(1-ŷ + ϵ)``. The `ϵ` term provides numerical stability.

Typically, the prediction `ŷ` is given by the output of a [`sigmoid`](@ref) activation.
See also: [`Flux.crossentropy`](@ref), [`Flux.logitcrossentropy`](@ref), [`Flux.logitbinarycrossentropy`](@ref)
# Examples
```jldoctest
julia> Flux.binarycrossentropy.(σ.([-1.1491, 0.8619, 0.3127]), [1, 1, 0])
3-element Array{Float64,1}:
1.424397097347566
0.35231664672364077
0.8616703662235441
```
"""
# binarycrossentropy(ŷ, y; ϵ=eps(ŷ)) = -y*log(ŷ + ϵ) - (1 - y)*log(1 - ŷ + ϵ)
function binarycrossentropy(ŷ, y; agg=mean, ϵ=eps(eltype(ŷ)))
    agg(@.(-y*log(ŷ+ϵ) - (1-y)*log(1-ŷ+ϵ)))
end
# Re-definition to fix interaction with CuArrays.
# CuArrays.@cufunc binarycrossentropy(ŷ, y; ϵ=eps(ŷ)) = -y*log(ŷ + ϵ) - (1 - y)*log(1 - ŷ + ϵ)
"""
logitbinarycrossentropy(ŷ, y)
`logitbinarycrossentropy(ŷ, y)` is mathematically equivalent to
[`Flux.binarycrossentropy(σ(log(ŷ)), y)`](@ref) but it is more numerically stable.
See also: [`Flux.crossentropy`](@ref), [`Flux.logitcrossentropy`](@ref), [`Flux.binarycrossentropy`](@ref)
# Examples
```jldoctest
julia> Flux.logitbinarycrossentropy.([-1.1491, 0.8619, 0.3127], [1, 1, 0])
3-element Array{Float64,1}:
1.4243970973475661
0.35231664672364094
0.8616703662235443
```
"""
# logitbinarycrossentropy(ŷ, y) = (1 - y)*ŷ - logσ(ŷ)
function logitbinarycrossentropy(ŷ, y; agg=mean)
    agg(@.((1-y)*ŷ - logsigmoid(ŷ)))
end
# Re-definition to fix interaction with CuArrays.
# CuArrays.@cufunc logitbinarycrossentropy(ŷ, y) = (1 - y)*ŷ - logσ(ŷ)
# TODO normalise over last dimension is typically what you want to do.
# Possible deprecation path: `normalise(x; dims=1)` -> `normalise(x; dims)` -> `normalise(x; dims=size(x)[end])`
"""
@ -197,77 +33,6 @@ function normalise(x::AbstractArray; dims=1, ϵ=ofeltype(x, 1e-6))
return (x .- μ′) ./ (σ′ .+ ϵ)
end
"""
    kldivergence(ŷ, y)
Return the
[Kullback-Leibler divergence](https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence)
between the given probability distributions.
KL divergence is a measure of how much one probability distribution is different
from the other.
It is always non-negative and zero only when both the distributions are equal
everywhere.
"""
function kldivergence(ŷ, y; dims=1, agg=mean, ϵ=eps(eltype(ŷ)))
    entropy = agg(sum(y .* log.(y .+ ϵ), dims=dims))
    cross_entropy = crossentropy(ŷ, y; dims=dims, agg=agg, ϵ=ϵ)
return entropy + cross_entropy
end
"""
    poisson(ŷ, y)
# Return how much the predicted distribution `ŷ` diverges from the expected Poisson
# distribution `y`; calculated as `sum(ŷ .- y .* log.(ŷ)) / size(y, 2)`.
REDO
[More information.](https://peltarion.com/knowledge-center/documentation/modeling-view/build-an-ai-model/loss-functions/poisson).
"""
poisson(ŷ, y; agg=mean) = agg(ŷ .- y .* log.(ŷ))
"""
    hinge(ŷ, y)

Return the [hinge loss](https://en.wikipedia.org/wiki/Hinge_loss) given the
prediction `ŷ` and true labels `y` (containing 1 or -1); calculated as
`sum(max.(0, 1 .- ŷ .* y)) / size(y, 2)`.
See also: [`squared_hinge`](@ref)
"""
hinge(ŷ, y; agg=mean) = agg(max.(0, 1 .- ŷ .* y))
"""
    squared_hinge(ŷ, y)

Return the squared hinge loss given the prediction `ŷ` and true labels `y`
(containing 1 or -1); calculated as `sum((max.(0, 1 .- ŷ .* y)).^2) / size(y, 2)`.
See also: [`hinge`](@ref)
"""
squared_hinge(ŷ, y; agg=mean) = agg((max.(0, 1 .- ŷ .* y)).^2)
"""
    dice_coeff_loss(ŷ, y; smooth=1)
Return a loss based on the dice coefficient.
Used in the [V-Net](https://arxiv.org/pdf/1606.04797v1.pdf) image segmentation
architecture.
Similar to the F1_score. Calculated as:
    1 - 2*sum(|ŷ .* y| + smooth) / (sum(ŷ.^2) + sum(y.^2) + smooth)
"""
dice_coeff_loss(ŷ, y; smooth=ofeltype(ŷ, 1.0)) = 1 - (2*sum(y .* ŷ) + smooth) / (sum(y.^2) + sum(ŷ.^2) + smooth) #TODO
"""
    tversky_loss(ŷ, y; β=0.7)
Return the [Tversky loss](https://arxiv.org/pdf/1706.05721.pdf).
Used with imbalanced data to give more weight to false negatives.
Larger β weigh recall higher than precision (by placing more emphasis on false negatives)
Calculated as:
    1 - sum(|y .* ŷ| + 1) / (sum(y .* ŷ + β*(1 .- y) .* ŷ + (1 - β)*y .* (1 .- ŷ)) + 1)
"""
tversky_loss(ŷ, y; β=ofeltype(ŷ, 0.7)) = 1 - (sum(y .* ŷ) + 1) / (sum(y .* ŷ + β*(1 .- y) .* ŷ + (1 - β)*y .* (1 .- ŷ)) + 1) #TODO
"""
flatten(x::AbstractArray)

View File

@ -33,8 +33,8 @@ cx = gpu(x)
x = [-1.1491, 0.8619, 0.3127]
y = [1, 1, 0.]
@test Flux.binarycrossentropy.(σ.(x),y) ≈ Array(Flux.binarycrossentropy.(cu(σ.(x)),cu(y)))
@test Flux.logitbinarycrossentropy.(x,y) ≈ Array(Flux.logitbinarycrossentropy.(cu(x),cu(y)))
@test Flux.binarycrossentropy(σ.(x), y) ≈ Flux.binarycrossentropy(cu(σ.(x)), cu(y))
@test Flux.logitbinarycrossentropy(x, y) ≈ Flux.logitbinarycrossentropy(cu(x), cu(y))
xs = rand(5, 5)
ys = Flux.onehotbatch(1:5,1:5)

View File

@ -56,12 +56,12 @@ const ϵ = 1e-7
logŷ, y = randn(3), rand(3)
@testset "binarycrossentropy" begin
@test binarycrossentropy.(σ.(logŷ), y; ϵ=0) ≈ -y.*log.(σ.(logŷ)) - (1 .- y).*log.(1 .- σ.(logŷ))
@test binarycrossentropy.(σ.(logŷ), y) ≈ -y.*log.(σ.(logŷ) .+ eps.(σ.(logŷ))) - (1 .- y).*log.(1 .- σ.(logŷ) .+ eps.(σ.(logŷ)))
@test binarycrossentropy(σ.(logŷ), y; ϵ=0) ≈ mean(-y.*log.(σ.(logŷ)) - (1 .- y).*log.(1 .- σ.(logŷ)))
@test binarycrossentropy(σ.(logŷ), y) ≈ mean(-y.*log.(σ.(logŷ) .+ eps.(σ.(logŷ))) - (1 .- y).*log.(1 .- σ.(logŷ) .+ eps.(σ.(logŷ))))
end
@testset "logitbinarycrossentropy" begin
@test logitbinarycrossentropy.(logŷ, y) ≈ binarycrossentropy.(σ.(logŷ), y; ϵ=0)
@test logitbinarycrossentropy(logŷ, y) ≈ binarycrossentropy(σ.(logŷ), y; ϵ=0)
end
y = [1 2 3]
@ -86,8 +86,8 @@ const ϵ = 1e-7
y = [0.1 0.2 0.3]
ŷ = [0.4 0.5 0.6]
@testset "poisson" begin
@test Flux.poisson(ŷ, y) ≈ 0.6278353988097339
@test Flux.poisson(y, y) ≈ 0.5044459776946685
@test Flux.poisson_loss(ŷ, y) ≈ 0.6278353988097339
@test Flux.poisson_loss(y, y) ≈ 0.5044459776946685
end
y = [1.0 0.5 0.3 2.4]