using NNlib: logsoftmax, logσ
using Statistics: mean, std  # assumption: imported here so `normalise` below is self-contained; Flux may already import these at module level

# Cost functions

"""
    mse(ŷ, y)

Return the mean squared error `sum((ŷ .- y).^2) / length(y)`.
"""
mse(ŷ, y) = sum((ŷ .- y).^2) * 1 // length(y)
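
# A quick sanity check with made-up values: the squared errors are 1 and 1,
# so their mean is 1 (returned as the exact rational 1//1).
#
#   julia> mse([2, 4], [1, 3])
#   1//1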

"""
    crossentropy(ŷ, y; weight = 1)

Return the cross entropy `-sum(y .* log.(ŷ) .* weight) / size(y, 2)`, treating
each column of `ŷ` and `y` as one observation.
"""
function crossentropy(ŷ::AbstractVecOrMat, y::AbstractVecOrMat; weight = 1)
  -sum(y .* log.(ŷ) .* weight) * 1 // size(y, 2)
end
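
# Illustrative check (made-up numbers): with a one-hot target, only the
# predicted probability of the true class contributes.
#
#   julia> crossentropy([0.1, 0.9], [0, 1]) ≈ -log(0.9)
#   true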

"""
    logitcrossentropy(logŷ, y; weight = 1)

`logitcrossentropy(logŷ, y)` is mathematically equivalent to
`crossentropy(softmax(logŷ), y)`, but is more numerically stable.
"""
function logitcrossentropy(logŷ::AbstractVecOrMat, y::AbstractVecOrMat; weight = 1)
  return -sum(y .* logsoftmax(logŷ) .* weight) * 1 // size(y, 2)
end
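
# A small equivalence sketch (made-up logits): both routes agree up to
# floating-point error, but the logit form never materialises softmax(logŷ).
#
#   julia> using NNlib: softmax
#
#   julia> logŷ = [1.0, 2.0, 3.0]; y = [0, 0, 1];
#
#   julia> logitcrossentropy(logŷ, y) ≈ crossentropy(softmax(logŷ), y)
#   true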
"""
|
2018-06-26 18:29:06 +00:00
|
|
|
|
binarycrossentropy(ŷ, y; ϵ=eps(ŷ))
|
2018-02-06 11:32:46 +00:00
|
|
|
|
|
2018-06-26 17:43:16 +00:00
|
|
|
|
Return `-y*log(ŷ + ϵ) - (1-y)*log(1-ŷ + ϵ)`. The ϵ term provides numerical stability.
|
2018-02-06 11:32:46 +00:00
|
|
|
|
|
|
|
|
|
julia> binarycrossentropy.(σ.([-1.1491, 0.8619, 0.3127]), [1, 1, 0.])
|
|
|
|
|
3-element Array{Float64,1}:
|
|
|
|
|
1.4244
|
|
|
|
|
0.352317
|
|
|
|
|
0.86167
|
|
|
|
|
"""
|
2018-06-27 05:55:43 +00:00
|
|
|
|
binarycrossentropy(ŷ, y; ϵ=eps(ŷ)) = -y*log(ŷ + ϵ) - (1 - y)*log(1 - ŷ + ϵ)
"""
|
|
|
|
|
logitbinarycrossentropy(logŷ, y)
|
|
|
|
|
|
|
|
|
|
`logitbinarycrossentropy(logŷ, y)` is mathematically equivalent to `binarycrossentropy(σ(logŷ), y)`
|
|
|
|
|
but it is more numerically stable.
|
|
|
|
|
|
|
|
|
|
julia> logitbinarycrossentropy.([-1.1491, 0.8619, 0.3127], [1, 1, 0.])
|
|
|
|
|
3-element Array{Float64,1}:
|
|
|
|
|
1.4244
|
|
|
|
|
0.352317
|
|
|
|
|
0.86167
|
|
|
|
|
"""
|
|
|
|
|
logitbinarycrossentropy(logŷ, y) = (1 - y)*logŷ - logσ(logŷ)
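
# Why the one-liner above is valid (a short derivation, no new API): writing
# x = logŷ and using σ(-x) = 1 - σ(x), hence log(1 - σ(x)) = logσ(x) - x,
#   -y*logσ(x) - (1-y)*log(1 - σ(x)) = (1-y)*x - logσ(x),
# so the loss is computed without ever forming σ(x) and taking its log.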
"""
|
2019-02-08 18:49:53 +00:00
|
|
|
|
normalise(x::AbstractArray; dims=1)
|
2017-10-10 20:33:37 +00:00
|
|
|
|
|
2019-02-05 11:39:22 +00:00
|
|
|
|
Normalises x to mean 0 and standard deviation 1, across the dimensions given by dims. Defaults to normalising over columns.
|
2017-10-10 20:33:37 +00:00
|
|
|
|
"""
|
2019-02-08 13:15:37 +00:00
|
|
|
|
function normalise(x::AbstractArray; dims=1)
|
2019-02-05 11:39:22 +00:00
|
|
|
|
μ′ = mean(x, dims = dims)
|
2019-02-05 13:06:04 +00:00
|
|
|
|
σ′ = std(x, dims = dims, mean = μ′, corrected=false)
|
2017-10-23 11:53:07 +00:00
|
|
|
|
return (x .- μ′) ./ σ′
|
2017-10-10 20:33:37 +00:00
|
|
|
|
end
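
# Illustrative usage (made-up matrix): with the default dims=1 each column is
# normalised independently, so every column ends up with mean 0 and
# (uncorrected) standard deviation 1.
#
#   julia> normalise([1.0 2.0; 3.0 4.0]) ≈ [-1.0 -1.0; 1.0 1.0]
#   true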

function normalise(x::AbstractArray, dims)
  Base.depwarn("`normalise(x::AbstractArray, dims)` is deprecated, use `normalise(a, dims=dims)` instead.", :normalise)
  normalise(x, dims = dims)
end
"""
|
|
|
|
|
Kullback Leibler Divergence(KL Divergence)
|
|
|
|
|
KLDivergence is a measure of how much one probability distribution is different from the other.
|
|
|
|
|
It is always non-negative and zero only when both the distributions are equal everywhere.
|
2019-03-25 21:45:28 +00:00
|
|
|
|
https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence
|
2019-03-11 21:01:42 +00:00
|
|
|
|
"""
|
2019-03-25 21:39:48 +00:00
|
|
|
|
function kldivergence(ŷ, y)
|
2019-03-11 21:01:42 +00:00
|
|
|
|
entropy = sum(y .* log.(y)) *1 //size(y,2)
|
|
|
|
|
cross_entropy = crossentropy(ŷ, y)
|
|
|
|
|
return entropy + cross_entropy
|
|
|
|
|
end
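
# Sanity check (made-up distribution): a distribution has zero divergence
# from itself, since the entropy and cross entropy terms cancel exactly.
#
#   julia> p = [0.25, 0.75];
#
#   julia> kldivergence(p, p)
#   0.0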
"""
|
|
|
|
|
Poisson Loss function
|
|
|
|
|
Poisson loss function is a measure of how the predicted distribution diverges from the expected distribution.
|
2019-03-25 21:45:28 +00:00
|
|
|
|
https://isaacchanghau.github.io/post/loss_functions/
|
2019-03-11 21:01:42 +00:00
|
|
|
|
"""
|
2019-03-25 21:39:48 +00:00
|
|
|
|
poisson(ŷ, y) = sum(ŷ .- y .* log.(ŷ)) *1 // size(y,2)
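
# Illustrative values: for a single count, ŷ .- y .* log.(ŷ) is minimised at
# ŷ == y; at ŷ = y = 1 the loss is 1 - 1*log(1) = 1.
#
#   julia> poisson([1.0], [1.0])
#   1.0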
"""
|
|
|
|
|
Logcosh Loss function
|
|
|
|
|
"""
|
|
|
|
|
logcosh(ŷ, y) = sum(log.(cosh.(ŷ .- y)))
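
# Behaviour sketch (made-up residuals): log(cosh(x)) ≈ x^2/2 for small x and
# ≈ abs(x) - log(2) for large x, so outliers are penalised only linearly.
#
#   julia> isapprox(logcosh([0.01], [0.0]), 0.01^2 / 2; rtol = 1e-4)
#   true
#
#   julia> isapprox(logcosh([10.0], [0.0]), 10 - log(2); rtol = 1e-8)
#   true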

"""
    hinge(ŷ, y)

Hinge loss `sum(max.(0, 1 .- ŷ .* y)) / size(y, 2)`, for predictions `ŷ` and
true labels `y` containing 1 or -1.

https://en.wikipedia.org/wiki/Hinge_loss
"""
hinge(ŷ, y) = sum(max.(0.0, 1 .- ŷ .* y)) * 1 // size(y, 2)
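
# Illustrative values: a correct, confident prediction (ŷ*y ≥ 1) contributes
# nothing, while a weakly correct one still incurs some loss.
#
#   julia> hinge([0.3, -2.0], [1, -1]) ≈ 0.7
#   true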