fixes
This commit is contained in:
parent
20ed5c5622
commit
508b392204
|
@ -27,7 +27,7 @@ The loss corresponding to mean squared logarithmic errors, calculated as
|
|||
The `ϵ` term provides numerical stability.
|
||||
Penalizes an under-predicted estimate more than an over-predicted estimate.
|
||||
"""
|
||||
msle(ŷ, y; agg=mean, ϵ=eps(eltype(ŷ))) = agg((log.(ŷ .+ ϵ) .- log.(y .+ ϵ)).^2)
|
||||
msle(ŷ, y; agg=mean, ϵ=epseltype(ŷ)) = agg((log.(ŷ .+ ϵ) .- log.(y .+ ϵ)).^2)
|
||||
|
||||
|
||||
"""
|
||||
|
@ -64,7 +64,7 @@ calculated as
|
|||
|
||||
See also: [`Flux.logitcrossentropy`](@ref), [`Flux.binarycrossentropy`](@ref), [`Flux.logitbinarycrossentropy`](@ref)
|
||||
"""
|
||||
function crossentropy(ŷ, y; dims=1, agg=mean, ϵ=eps(eltype(ŷ)), weight=nothing)
|
||||
function crossentropy(ŷ, y; dims=1, agg=mean, ϵ=epseltype(ŷ), weight=nothing)
|
||||
agg(.-wsum(weight, y .* log.(ŷ .+ ϵ); dims=dims))
|
||||
end
|
||||
|
||||
|
@ -94,7 +94,7 @@ Typically, the prediction `ŷ` is given by the output of a [`sigmoid`](@ref) ac
|
|||
|
||||
See also: [`Flux.crossentropy`](@ref), [`Flux.logitcrossentropy`](@ref), [`Flux.logitbinarycrossentropy`](@ref)
|
||||
"""
|
||||
function binarycrossentropy(ŷ, y; agg=mean, ϵ=eps(eltype(ŷ)))
|
||||
function binarycrossentropy(ŷ, y; agg=mean, ϵ=epseltype(ŷ))
|
||||
agg(@.(-y*log(ŷ+ϵ) - (1-y)*log(1-ŷ+ϵ)))
|
||||
end
|
||||
|
||||
|
@ -128,21 +128,21 @@ from the other.
|
|||
It is always non-negative and zero only when both the distributions are equal
|
||||
everywhere.
|
||||
"""
|
||||
function kldivergence(ŷ, y; dims=1, agg=mean, ϵ=eps(eltype(ŷ)))
|
||||
function kldivergence(ŷ, y; dims=1, agg=mean, ϵ=epseltype(ŷ))
|
||||
entropy = agg(sum(y .* log.(y .+ ϵ), dims=dims))
|
||||
cross_entropy = crossentropy(ŷ, y; dims=dims, agg=agg, ϵ=ϵ)
|
||||
return entropy + cross_entropy
|
||||
end
|
||||
|
||||
"""
|
||||
poisson_loss(ŷ, y; agg=mean)
|
||||
poisson_loss(ŷ, y; agg=mean, ϵ=eps(eltype(ŷ))))
|
||||
|
||||
# Return how much the predicted distribution `ŷ` diverges from the expected Poisson
|
||||
# distribution `y`; calculated as `sum(ŷ .- y .* log.(ŷ)) / size(y, 2)`.
|
||||
REDO
|
||||
[More information.](https://peltarion.com/knowledge-center/documentation/modeling-view/build-an-ai-model/loss-functions/poisson).
|
||||
"""
|
||||
poisson_loss(ŷ, y; agg=mean) = agg(ŷ .- y .* log.(ŷ))
|
||||
poisson_loss(ŷ, y; agg=mean, ϵ=epseltype(ŷ)) = agg(ŷ .- y .* log.(ŷ .+ ϵ))
|
||||
|
||||
@deprecate poisson poisson_loss
|
||||
|
||||
|
|
|
@ -5,6 +5,8 @@ nfan(n_out, n_in) = n_in, n_out # In case of Dense kernels: arranged as matrices
|
|||
nfan(dims...) = prod(dims[1:end-2]) .* (dims[end-1], dims[end]) # In case of convolution kernels
|
||||
|
||||
ofeltype(x, y) = convert(float(eltype(x)), y)
|
||||
epseltype(x) = eps(float(eltype(x)))
|
||||
|
||||
|
||||
"""
|
||||
glorot_uniform(dims...)
|
||||
|
|
|
@ -107,7 +107,7 @@ const ϵ = 1e-7
|
|||
for T in (Float32, Float64)
|
||||
y = rand(T, 2)
|
||||
ŷ = rand(T, 2)
|
||||
for f in (mse, crossentropy, logitcrossentropy, Flux.kldivergence, Flux.hinge, Flux.poisson,
|
||||
for f in (mse, crossentropy, logitcrossentropy, Flux.kldivergence, Flux.hinge, Flux.poisson_loss,
|
||||
Flux.mae, Flux.huber_loss, Flux.msle, Flux.squared_hinge, Flux.dice_coeff_loss, Flux.tversky_loss)
|
||||
fwd, back = Flux.pullback(f, ŷ, y)
|
||||
@test fwd isa T
|
||||
|
|
Loading…
Reference in New Issue