Flux.jl/src/layers/stateless.jl

57 lines
1.7 KiB
Julia
Raw Normal View History

2018-03-01 16:31:20 +00:00
using NNlib: logsoftmax, logσ
2017-11-09 15:03:57 +00:00
2017-08-19 19:52:29 +00:00
# Cost functions
"""
    mse(ŷ, y)

Return the mean squared error between the prediction `ŷ` and the target `y`,
i.e. `sum((ŷ .- y).^2) / length(y)`.

The division is written as multiplication by the `Rational` `1 // length(y)`
(`//` binds tighter than `*`), so a `Float32` input yields a `Float32` result.
"""
mse(ŷ, y) = sum((ŷ .- y).^2) * 1 // length(y)
2017-08-19 19:52:29 +00:00
2017-12-13 15:27:15 +00:00
"""
    crossentropy(ŷ, y; weight=1)

Return the cross-entropy `-sum(y .* log.(ŷ) .* weight) / size(y, 2)` between the
predicted probabilities `ŷ` and the target `y`.

The sum is divided by `size(y, 2)`, so for a matrix the loss is averaged over
columns. A vector counts as a single column. `weight` is broadcast against
`y .* log.(ŷ)`, so it may be a scalar or an array of compatible shape.
"""
function crossentropy(ŷ::AbstractVecOrMat, y::AbstractVecOrMat; weight = 1)
    return -sum(y .* log.(ŷ) .* weight) * 1 // size(y, 2)
end
"""
    logitcrossentropy(logŷ, y; weight=1)

Return the cross-entropy computed from unnormalised scores `logŷ`:
`-sum(y .* logsoftmax(logŷ) .* weight) / size(y, 2)`.

Because `log(softmax(x)) == logsoftmax(x)`, this is mathematically equivalent to
`crossentropy(softmax(logŷ), y; weight=weight)`. It applies NNlib's `logsoftmax`
directly instead of taking `log` of probabilities. As in `crossentropy`, the sum
is divided by `size(y, 2)`, the number of columns.
"""
function logitcrossentropy(logŷ::AbstractVecOrMat, y::AbstractVecOrMat; weight = 1)
    # NOTE(review): logsoftmax presumably normalises along dim 1 (NNlib default) — confirm.
    return -sum(y .* logsoftmax(logŷ) .* weight) * 1 // size(y, 2)
end
2017-10-10 20:33:37 +00:00
"""
2018-06-26 18:29:06 +00:00
binarycrossentropy(, y; ϵ=eps())
Return `-y*log(ŷ + ϵ) - (1-y)*log(1-ŷ + ϵ)`. The ϵ term provides numerical stability.
julia> binarycrossentropy.(σ.([-1.1491, 0.8619, 0.3127]), [1, 1, 0.])
3-element Array{Float64,1}:
1.4244
0.352317
0.86167
"""
2018-06-27 05:55:43 +00:00
binarycrossentropy(, y; ϵ=eps()) = -y*log( + ϵ) - (1 - y)*log(1 - + ϵ)
"""
logitbinarycrossentropy(logŷ, y)
`logitbinarycrossentropy(logŷ, y)` is mathematically equivalent to `binarycrossentropy(σ(logŷ), y)`
but it is more numerically stable.
julia> logitbinarycrossentropy.([-1.1491, 0.8619, 0.3127], [1, 1, 0.])
3-element Array{Float64,1}:
1.4244
0.352317
0.86167
"""
logitbinarycrossentropy(logŷ, y) = (1 - y)*logŷ - logσ(logŷ)
2017-10-10 20:33:37 +00:00
"""
2019-02-08 18:49:53 +00:00
normalise(x::AbstractArray; dims=1)
2017-10-10 20:33:37 +00:00
Normalises x to mean 0 and standard deviation 1, across the dimensions given by dims. Defaults to normalising over columns.
2017-10-10 20:33:37 +00:00
"""
2019-02-08 13:15:37 +00:00
function normalise(x::AbstractArray; dims=1)
μ′ = mean(x, dims = dims)
2019-02-05 13:06:04 +00:00
σ = std(x, dims = dims, mean = μ′, corrected=false)
2017-10-23 11:53:07 +00:00
return (x .- μ′) ./ σ
2017-10-10 20:33:37 +00:00
end
2019-02-08 13:00:32 +00:00
2019-02-08 13:15:37 +00:00
# Deprecated positional-`dims` form of `normalise`.
# The positional argument deliberately has no default. Writing `dims=1` here would
# also define the one-argument method `normalise(x::AbstractArray)`, which would
# overwrite the entry point of the keyword method
# `normalise(x::AbstractArray; dims=1)`. A plain `normalise(x)` would then emit a
# spurious deprecation warning. `normalise(x)` is still served by the keyword method.
function normalise(x::AbstractArray, dims)
    Base.depwarn("`normalise(x::AbstractArray, dims)` is deprecated, use `normalise(a, dims=dims)` instead.", :normalise)
    return normalise(x, dims = dims)
end