2018-02-06 11:32:46 +00:00
|
|
|
|
using NNlib: log_fast, logsoftmax, logσ
|
2017-11-09 15:03:57 +00:00
|
|
|
|
|
2017-08-19 19:52:29 +00:00
|
|
|
|
# Cost functions
|
|
|
|
|
|
2017-08-24 10:40:51 +00:00
|
|
|
|
"""
    mse(ŷ, y)

Return the mean squared error between predictions `ŷ` and targets `y`:
the element-wise differences `ŷ .- y` are squared, summed, and divided by
`length(y)`.
"""
function mse(ŷ, y)
    residual = ŷ .- y
    return sum(residual .^ 2) / length(y)
end
|
2017-08-19 19:52:29 +00:00
|
|
|
|
|
2017-12-13 15:27:15 +00:00
|
|
|
|
"""
    crossentropy(ŷ, y; weight = 1)

Return the cross-entropy between predicted probabilities `ŷ` and targets `y`.
The loss is `-sum(y .* log.(ŷ) .* weight)` divided by `size(y, 2)`, i.e.
averaged over the columns of `y`. A vector counts as a single column, so
the sum is returned as-is.

`weight` is broadcast-multiplied into every term. It may be a scalar or
anything that broadcasts against `y`.

The logarithm is NNlib's `log_fast` (presumably a faster `log`; see NNlib).
"""
function crossentropy(ŷ::AbstractVecOrMat, y::AbstractVecOrMat; weight = 1)
    ncols = size(y, 2)
    weighted = sum(y .* log_fast.(ŷ) .* weight)
    return -weighted / ncols
end
|
|
|
|
|
|
2017-10-17 16:36:18 +00:00
|
|
|
|
# Deprecated alias: `logloss(x, y)` is forwarded to `crossentropy(x, y)` via
# Base's `@deprecate`. Only the two positional arguments are forwarded, so
# `crossentropy`'s `weight` keyword is not reachable through the old name.
@deprecate logloss(x, y) crossentropy(x, y)
|
2017-10-17 16:57:10 +00:00
|
|
|
|
|
2018-02-06 11:32:46 +00:00
|
|
|
|
"""
    logitcrossentropy(logŷ, y; weight = 1)

Cross-entropy computed from unnormalised scores `logŷ` rather than from
probabilities. `logsoftmax` is applied to `logŷ` first; it normalises
presumably along the first dimension, i.e. per column — confirm against NNlib.
The result is `-sum(y .* logsoftmax(logŷ) .* weight)` divided by `size(y, 2)`.

This avoids the separate `softmax` followed by `log` that `crossentropy`
would need.
"""
function logitcrossentropy(logŷ::AbstractVecOrMat, y::AbstractVecOrMat; weight = 1)
    logprobs = logsoftmax(logŷ)
    ncols = size(y, 2)
    return -sum(y .* logprobs .* weight) / ncols
end
|
2017-10-10 20:33:37 +00:00
|
|
|
|
|
2018-02-06 11:32:46 +00:00
|
|
|
|
"""
    binarycrossentropy(ŷ, y; ϵ = eps(ŷ))

Return `-y*log(ŷ + ϵ) - (1-y)*log(1 - ŷ + ϵ)`: the binary cross-entropy of a
single predicted probability `ŷ` against a target `y`. Broadcast it to apply
it element-wise.

The small `ϵ` keeps the result finite when `ŷ` is exactly `0` or `1`. Without
it, a perfectly confident correct prediction (e.g. `ŷ = 1`, `y = 1`) evaluates
`0 * log(0) == 0 * -Inf`, which is `NaN`. Pass `ϵ = 0` to recover the
unregularised formula.

# Examples
```julia
julia> binarycrossentropy.(σ.([-1.1491, 0.8619, 0.3127]), [1, 1, 0.])
3-element Array{Float64,1}:
 1.4244
 0.352317
 0.86167
```
"""
binarycrossentropy(ŷ, y; ϵ = eps(ŷ)) = -y*log_fast(ŷ + ϵ) - (1 - y)*log_fast(1 - ŷ + ϵ)
|
|
|
|
|
|
|
|
|
|
"""
    logitbinarycrossentropy(logŷ, y)

Binary cross-entropy computed directly from a logit `logŷ`. It is
mathematically equivalent to `binarycrossentropy(σ(logŷ), y)` but more
numerically stable, because `σ(logŷ)` is never formed explicitly.

# Examples
```julia
julia> logitbinarycrossentropy.([-1.1491, 0.8619, 0.3127], [1, 1, 0.])
3-element Array{Float64,1}:
 1.4244
 0.352317
 0.86167
```
"""
function logitbinarycrossentropy(logŷ, y)
    # Derivation, with x = logŷ:
    #   -y*log σ(x) - (1-y)*log(1 - σ(x))
    #   = -y*logσ(x) - (1-y)*(logσ(x) - x)    since log(1 - σ(x)) == logσ(x) - x
    #   = (1-y)*x - logσ(x)
    offset = (1 - y) * logŷ
    return offset - logσ(logŷ)
end
|
|
|
|
|
|
2017-10-10 20:33:37 +00:00
|
|
|
|
"""
    normalise(x::AbstractVecOrMat)

Normalise each column of `x` to mean 0 and standard deviation 1.

Each column is shifted by its mean and divided by its standard deviation, as
computed by `std` with its default settings. A vector is treated as a single
column.

A column whose standard deviation is zero is divided by zero, so its entries
become `NaN`/`Inf`.
"""
function normalise(x::AbstractVecOrMat)
    colmean = mean(x, 1)
    colstd = std(x, 1, mean = colmean)
    centred = x .- colmean
    return centred ./ colstd
end
|