This commit is contained in:
CarloLucibello 2020-04-29 12:31:59 +02:00
parent 20ed5c5622
commit 508b392204
3 changed files with 9 additions and 7 deletions

View File

@ -27,7 +27,7 @@ The loss corresponding to mean squared logarithmic errors, calculated as
The `ϵ` term provides numerical stability. The `ϵ` term provides numerical stability.
Penalizes an under-predicted estimate more than an over-predicted estimate. Penalizes an under-predicted estimate more than an over-predicted estimate.
""" """
msle(ŷ, y; agg=mean, ϵ=eps(eltype(ŷ))) = agg((log.(ŷ .+ ϵ) .- log.(y .+ ϵ)).^2) msle(ŷ, y; agg=mean, ϵ=epseltype(ŷ)) = agg((log.(ŷ .+ ϵ) .- log.(y .+ ϵ)).^2)
""" """
@ -64,7 +64,7 @@ calculated as
See also: [`Flux.logitcrossentropy`](@ref), [`Flux.binarycrossentropy`](@ref), [`Flux.logitbinarycrossentropy`](@ref) See also: [`Flux.logitcrossentropy`](@ref), [`Flux.binarycrossentropy`](@ref), [`Flux.logitbinarycrossentropy`](@ref)
""" """
function crossentropy(ŷ, y; dims=1, agg=mean, ϵ=eps(eltype(ŷ)), weight=nothing) function crossentropy(ŷ, y; dims=1, agg=mean, ϵ=epseltype(ŷ), weight=nothing)
agg(.-wsum(weight, y .* log.(ŷ .+ ϵ); dims=dims)) agg(.-wsum(weight, y .* log.(ŷ .+ ϵ); dims=dims))
end end
@ -94,7 +94,7 @@ Typically, the prediction `ŷ` is given by the output of a [`sigmoid`](@ref) ac
See also: [`Flux.crossentropy`](@ref), [`Flux.logitcrossentropy`](@ref), [`Flux.logitbinarycrossentropy`](@ref) See also: [`Flux.crossentropy`](@ref), [`Flux.logitcrossentropy`](@ref), [`Flux.logitbinarycrossentropy`](@ref)
""" """
function binarycrossentropy(ŷ, y; agg=mean, ϵ=eps(eltype(ŷ))) function binarycrossentropy(ŷ, y; agg=mean, ϵ=epseltype(ŷ))
agg(@.(-y*log(ŷ+ϵ) - (1-y)*log(1-ŷ+ϵ))) agg(@.(-y*log(ŷ+ϵ) - (1-y)*log(1-ŷ+ϵ)))
end end
@ -128,21 +128,21 @@ from the other.
It is always non-negative and zero only when both the distributions are equal It is always non-negative and zero only when both the distributions are equal
everywhere. everywhere.
""" """
function kldivergence(ŷ, y; dims=1, agg=mean, ϵ=eps(eltype(ŷ))) function kldivergence(ŷ, y; dims=1, agg=mean, ϵ=epseltype(ŷ))
entropy = agg(sum(y .* log.(y .+ ϵ), dims=dims)) entropy = agg(sum(y .* log.(y .+ ϵ), dims=dims))
cross_entropy = crossentropy(ŷ, y; dims=dims, agg=agg, ϵ=ϵ) cross_entropy = crossentropy(ŷ, y; dims=dims, agg=agg, ϵ=ϵ)
return entropy + cross_entropy return entropy + cross_entropy
end end
""" """
poisson_loss(ŷ, y; agg=mean) poisson_loss(ŷ, y; agg=mean, ϵ=eps(eltype(ŷ)))
# Return how much the predicted distribution `ŷ` diverges from the expected Poisson # Return how much the predicted distribution `ŷ` diverges from the expected Poisson
# distribution `y`; calculated as `sum(ŷ .- y .* log.(ŷ)) / size(y, 2)`. # distribution `y`; calculated as `sum(ŷ .- y .* log.(ŷ)) / size(y, 2)`.
REDO REDO
[More information.](https://peltarion.com/knowledge-center/documentation/modeling-view/build-an-ai-model/loss-functions/poisson). [More information.](https://peltarion.com/knowledge-center/documentation/modeling-view/build-an-ai-model/loss-functions/poisson).
""" """
poisson_loss(ŷ, y; agg=mean) = agg(ŷ .- y .* log.(ŷ)) poisson_loss(ŷ, y; agg=mean, ϵ=epseltype(ŷ)) = agg(ŷ .- y .* log.(ŷ .+ ϵ))
@deprecate poisson poisson_loss @deprecate poisson poisson_loss

View File

@ -5,6 +5,8 @@ nfan(n_out, n_in) = n_in, n_out # In case of Dense kernels: arranged as matrices
nfan(dims...) = prod(dims[1:end-2]) .* (dims[end-1], dims[end]) # In case of convolution kernels nfan(dims...) = prod(dims[1:end-2]) .* (dims[end-1], dims[end]) # In case of convolution kernels
ofeltype(x, y) = convert(float(eltype(x)), y) ofeltype(x, y) = convert(float(eltype(x)), y)
epseltype(x) = eps(float(eltype(x)))
""" """
glorot_uniform(dims...) glorot_uniform(dims...)

View File

@ -107,7 +107,7 @@ const ϵ = 1e-7
for T in (Float32, Float64) for T in (Float32, Float64)
y = rand(T, 2) y = rand(T, 2)
ŷ = rand(T, 2) ŷ = rand(T, 2)
for f in (mse, crossentropy, logitcrossentropy, Flux.kldivergence, Flux.hinge, Flux.poisson, for f in (mse, crossentropy, logitcrossentropy, Flux.kldivergence, Flux.hinge, Flux.poisson_loss,
Flux.mae, Flux.huber_loss, Flux.msle, Flux.squared_hinge, Flux.dice_coeff_loss, Flux.tversky_loss) Flux.mae, Flux.huber_loss, Flux.msle, Flux.squared_hinge, Flux.dice_coeff_loss, Flux.tversky_loss)
fwd, back = Flux.pullback(f, ŷ, y) fwd, back = Flux.pullback(f, ŷ, y)
@test fwd isa T @test fwd isa T