fixes
This commit is contained in:
parent
20ed5c5622
commit
508b392204
@ -27,7 +27,7 @@ The loss corresponding to mean squared logarithmic errors, calculated as
|
|||||||
The `ϵ` term provides numerical stability.
|
The `ϵ` term provides numerical stability.
|
||||||
Penalizes an under-predicted estimate more than an over-predicted estimate.
|
Penalizes an under-predicted estimate more than an over-predicted estimate.
|
||||||
"""
|
"""
|
||||||
msle(ŷ, y; agg=mean, ϵ=eps(eltype(ŷ))) = agg((log.(ŷ .+ ϵ) .- log.(y .+ ϵ)).^2)
|
msle(ŷ, y; agg=mean, ϵ=epseltype(ŷ)) = agg((log.(ŷ .+ ϵ) .- log.(y .+ ϵ)).^2)
|
||||||
|
|
||||||
|
|
||||||
"""
|
"""
|
||||||
@ -64,7 +64,7 @@ calculated as
|
|||||||
|
|
||||||
See also: [`Flux.logitcrossentropy`](@ref), [`Flux.binarycrossentropy`](@ref), [`Flux.logitbinarycrossentropy`](@ref)
|
See also: [`Flux.logitcrossentropy`](@ref), [`Flux.binarycrossentropy`](@ref), [`Flux.logitbinarycrossentropy`](@ref)
|
||||||
"""
|
"""
|
||||||
function crossentropy(ŷ, y; dims=1, agg=mean, ϵ=eps(eltype(ŷ)), weight=nothing)
|
function crossentropy(ŷ, y; dims=1, agg=mean, ϵ=epseltype(ŷ), weight=nothing)
|
||||||
agg(.-wsum(weight, y .* log.(ŷ .+ ϵ); dims=dims))
|
agg(.-wsum(weight, y .* log.(ŷ .+ ϵ); dims=dims))
|
||||||
end
|
end
|
||||||
|
|
||||||
@ -94,7 +94,7 @@ Typically, the prediction `ŷ` is given by the output of a [`sigmoid`](@ref) ac
|
|||||||
|
|
||||||
See also: [`Flux.crossentropy`](@ref), [`Flux.logitcrossentropy`](@ref), [`Flux.logitbinarycrossentropy`](@ref)
|
See also: [`Flux.crossentropy`](@ref), [`Flux.logitcrossentropy`](@ref), [`Flux.logitbinarycrossentropy`](@ref)
|
||||||
"""
|
"""
|
||||||
function binarycrossentropy(ŷ, y; agg=mean, ϵ=eps(eltype(ŷ)))
|
function binarycrossentropy(ŷ, y; agg=mean, ϵ=epseltype(ŷ))
|
||||||
agg(@.(-y*log(ŷ+ϵ) - (1-y)*log(1-ŷ+ϵ)))
|
agg(@.(-y*log(ŷ+ϵ) - (1-y)*log(1-ŷ+ϵ)))
|
||||||
end
|
end
|
||||||
|
|
||||||
@ -128,21 +128,21 @@ from the other.
|
|||||||
It is always non-negative and zero only when both the distributions are equal
|
It is always non-negative and zero only when both the distributions are equal
|
||||||
everywhere.
|
everywhere.
|
||||||
"""
|
"""
|
||||||
function kldivergence(ŷ, y; dims=1, agg=mean, ϵ=eps(eltype(ŷ)))
|
function kldivergence(ŷ, y; dims=1, agg=mean, ϵ=epseltype(ŷ))
|
||||||
entropy = agg(sum(y .* log.(y .+ ϵ), dims=dims))
|
entropy = agg(sum(y .* log.(y .+ ϵ), dims=dims))
|
||||||
cross_entropy = crossentropy(ŷ, y; dims=dims, agg=agg, ϵ=ϵ)
|
cross_entropy = crossentropy(ŷ, y; dims=dims, agg=agg, ϵ=ϵ)
|
||||||
return entropy + cross_entropy
|
return entropy + cross_entropy
|
||||||
end
|
end
|
||||||
|
|
||||||
"""
|
"""
|
||||||
poisson_loss(ŷ, y; agg=mean)
|
poisson_loss(ŷ, y; agg=mean, ϵ=eps(eltype(ŷ))))
|
||||||
|
|
||||||
# Return how much the predicted distribution `ŷ` diverges from the expected Poisson
|
# Return how much the predicted distribution `ŷ` diverges from the expected Poisson
|
||||||
# distribution `y`; calculated as `sum(ŷ .- y .* log.(ŷ)) / size(y, 2)`.
|
# distribution `y`; calculated as `sum(ŷ .- y .* log.(ŷ)) / size(y, 2)`.
|
||||||
REDO
|
REDO
|
||||||
[More information.](https://peltarion.com/knowledge-center/documentation/modeling-view/build-an-ai-model/loss-functions/poisson).
|
[More information.](https://peltarion.com/knowledge-center/documentation/modeling-view/build-an-ai-model/loss-functions/poisson).
|
||||||
"""
|
"""
|
||||||
poisson_loss(ŷ, y; agg=mean) = agg(ŷ .- y .* log.(ŷ))
|
poisson_loss(ŷ, y; agg=mean, ϵ=epseltype(ŷ)) = agg(ŷ .- y .* log.(ŷ .+ ϵ))
|
||||||
|
|
||||||
@deprecate poisson poisson_loss
|
@deprecate poisson poisson_loss
|
||||||
|
|
||||||
|
@ -5,6 +5,8 @@ nfan(n_out, n_in) = n_in, n_out # In case of Dense kernels: arranged as matrices
|
|||||||
nfan(dims...) = prod(dims[1:end-2]) .* (dims[end-1], dims[end]) # In case of convolution kernels
|
nfan(dims...) = prod(dims[1:end-2]) .* (dims[end-1], dims[end]) # In case of convolution kernels
|
||||||
|
|
||||||
ofeltype(x, y) = convert(float(eltype(x)), y)
|
ofeltype(x, y) = convert(float(eltype(x)), y)
|
||||||
|
epseltype(x) = eps(float(eltype(x)))
|
||||||
|
|
||||||
|
|
||||||
"""
|
"""
|
||||||
glorot_uniform(dims...)
|
glorot_uniform(dims...)
|
||||||
|
@ -107,7 +107,7 @@ const ϵ = 1e-7
|
|||||||
for T in (Float32, Float64)
|
for T in (Float32, Float64)
|
||||||
y = rand(T, 2)
|
y = rand(T, 2)
|
||||||
ŷ = rand(T, 2)
|
ŷ = rand(T, 2)
|
||||||
for f in (mse, crossentropy, logitcrossentropy, Flux.kldivergence, Flux.hinge, Flux.poisson,
|
for f in (mse, crossentropy, logitcrossentropy, Flux.kldivergence, Flux.hinge, Flux.poisson_loss,
|
||||||
Flux.mae, Flux.huber_loss, Flux.msle, Flux.squared_hinge, Flux.dice_coeff_loss, Flux.tversky_loss)
|
Flux.mae, Flux.huber_loss, Flux.msle, Flux.squared_hinge, Flux.dice_coeff_loss, Flux.tversky_loss)
|
||||||
fwd, back = Flux.pullback(f, ŷ, y)
|
fwd, back = Flux.pullback(f, ŷ, y)
|
||||||
@test fwd isa T
|
@test fwd isa T
|
||||||
|
Loading…
Reference in New Issue
Block a user