fix tests
parent 654b100ce3
commit b44ba162b1
@@ -361,9 +361,9 @@ version = "1.2.11+9"
 [[Zygote]]
 deps = ["AbstractFFTs", "ArrayLayouts", "DiffRules", "FillArrays", "ForwardDiff", "IRTools", "InteractiveUtils", "LinearAlgebra", "MacroTools", "NNlib", "NaNMath", "Random", "Requires", "SpecialFunctions", "Statistics", "ZygoteRules"]
-git-tree-sha1 = "08ee0b7796c4c9ce644b9ecc326f3e047486baeb"
+git-tree-sha1 = "f7b0f77a86d2434abf693e3c0330e4682deed28d"
 uuid = "e88e6eb3-aa80-5325-afca-941959d7151f"
-version = "0.4.17"
+version = "0.4.18"

 [[ZygoteRules]]
 deps = ["MacroTools"]
@@ -50,7 +50,8 @@ wsum(w::Number, x; dims) = w .* sum(x, dims=dims)
 wsum(w::AbstractArray, x; dims) = sum( w .* x, dims=dims)

 """
-    crossentropy(ŷ, y; weight=nothing, dims=1, ϵ=eps(eltype(ŷ)), agg=mean)
+    crossentropy(ŷ, y; weight=nothing, dims=1, ϵ=eps(eltype(ŷ)),
+                 logits=false, agg=mean)

 Return the cross entropy between the given probability distributions;
 calculated as
@@ -60,16 +61,22 @@ calculated as
 `weight` can be `nothing`, a number or an array.
 `weight=nothing` acts like `weight=1` but is faster.

+If `logits=true`, the input `ŷ` is first fed to a [`softmax`](@ref) layer.
+
 See also: [`Flux.logitcrossentropy`](@ref), [`Flux.binarycrossentropy`](@ref), [`Flux.logitbinarycrossentropy`](@ref)
 """
-function crossentropy(ŷ, y; dims=1, agg=mean, ϵ=epseltype(ŷ), weight=nothing)
+function crossentropy(ŷ, y; dims=1, agg=mean, ϵ=epseltype(ŷ),
+                      weight=nothing, logits=false)
+    if logits
+        return logitcrossentropy(ŷ, y; dims=dims, agg=agg, weight=weight)
+    end
     agg(.-wsum(weight, y .* log.(ŷ .+ ϵ); dims=dims))
 end

 """
     logitcrossentropy(ŷ, y; weight=nothing, agg=mean, dims=1)

-Return the crossentropy computed after a [`Flux.logsoftmax`](@ref) operation;
+Return the cross entropy computed after a [`Flux.logsoftmax`](@ref) operation;
 calculated as

     agg(.-sum(weight .* y .* logsoftmax(ŷ; dims=dims); dims=dims))
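The practical effect of the new `logits` keyword is that `crossentropy` called on raw, unnormalised scores with `logits=true` simply forwards to `logitcrossentropy`. A minimal usage sketch (the random data and shapes below are made up for illustration, not part of the patch):

    using Flux

    scores = randn(3, 5)                        # raw scores ("logits"), 3 classes, 5 samples
    y = Flux.onehotbatch(rand(1:3, 5), 1:3)     # one-hot targets

    # logits=true dispatches to logitcrossentropy, so these two calls should agree
    Flux.crossentropy(scores, y; logits=true) ≈ Flux.logitcrossentropy(scores, y)

    # the default path still expects probabilities, e.g. a softmax output
    Flux.crossentropy(softmax(scores), y) ≈ Flux.logitcrossentropy(scores, y)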
@@ -84,15 +91,18 @@ function logitcrossentropy(ŷ, y; dims=1, agg=mean, weight=nothing)
 end

 """
-    binarycrossentropy(ŷ, y; agg=mean, ϵ=epseltype(ŷ))
+    binarycrossentropy(ŷ, y; agg=mean, ϵ=epseltype(ŷ), logits=false)

 Return ``-y*\\log(ŷ + ϵ) - (1-y)*\\log(1-ŷ + ϵ)``. The `ϵ` term provides numerical stability.

 Typically, the prediction `ŷ` is given by the output of a [`sigmoid`](@ref) activation.

+If `logits=true`, the input `ŷ` is first fed to a [`sigmoid`](@ref) activation.
+
 See also: [`Flux.crossentropy`](@ref), [`Flux.logitcrossentropy`](@ref), [`Flux.logitbinarycrossentropy`](@ref)
 """
-function binarycrossentropy(ŷ, y; agg=mean, ϵ=epseltype(ŷ))
+function binarycrossentropy(ŷ, y; agg=mean, ϵ=epseltype(ŷ), logits=false)
+    if logits
+        return logitbinarycrossentropy(ŷ, y; agg=agg)
+    end
     agg(@.(-y*log(ŷ+ϵ) - (1-y)*log(1-ŷ+ϵ)))
 end
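The binary case mirrors this: with `logits=true` the raw-score input is routed through `logitbinarycrossentropy`, the numerically safer path. A minimal sketch (the vectors below are illustrative only, not part of the patch):

    using Flux

    y = [0, 1, 0, 1]
    scores = randn(4)                 # raw scores, not yet passed through sigmoid

    # logits=true forwards to logitbinarycrossentropy, so these should agree
    Flux.binarycrossentropy(scores, y; logits=true) ≈ Flux.logitbinarycrossentropy(scores, y)

    # the default path expects probabilities, i.e. sigmoid output
    Flux.binarycrossentropy(sigmoid.(scores), y) ≈ Flux.logitbinarycrossentropy(scores, y)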
@@ -7,7 +7,6 @@ nfan(dims...) = prod(dims[1:end-2]) .* (dims[end-1], dims[end]) # In case of con
 ofeltype(x, y) = convert(float(eltype(x)), y)
 epseltype(x) = eps(float(eltype(x)))

-
 """
     glorot_uniform(dims...)
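For context, the two helpers in this hunk just lift scalar constants to the (floating) element type of an array. A small illustration using only the definitions shown above (the printed values are the standard machine epsilons):

    ofeltype(x, y) = convert(float(eltype(x)), y)
    epseltype(x) = eps(float(eltype(x)))

    ofeltype(rand(Float32, 3), 0.1)   # 0.1f0 — the constant becomes a Float32
    epseltype(rand(Float32, 3))       # eps(Float32) == 1.1920929f-7
    epseltype([1, 2, 3])              # eps(Float64) == 2.220446049250313e-16, since float(Int) == Float64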
@@ -89,18 +89,18 @@ const ϵ = 1e-7
   @test Flux.poisson_loss(ŷ, y) ≈ 0.6278353988097339
   @test Flux.poisson_loss(y, y) ≈ 0.5044459776946685
 end

 y = [1.0 0.5 0.3 2.4]
 ŷ = [0 1.4 0.5 1.2]
 @testset "dice_coeff_loss" begin
-  @test Flux.dice_coeff_loss(ŷ, y, dims=1) ≈ 0.2799999999999999
-  @test Flux.dice_coeff_loss(y, y, dims=1) ≈ 0.0
+  @test Flux.dice_coeff_loss(ŷ, y, dims=(1,2)) ≈ 0.2799999999999999
+  @test Flux.dice_coeff_loss(y, y, dims=(1,2)) ≈ 0.0
 end

 @testset "tversky_loss" begin
-  @test Flux.tversky_loss(ŷ, y, dims=1) ≈ -0.06772009029345383
-  @test Flux.tversky_loss(ŷ, y, dims=1, β = 0.8) ≈ -0.09490740740740744
-  @test Flux.tversky_loss(y, y, dims=1) ≈ -0.5576923076923075
+  @test Flux.tversky_loss(ŷ, y, dims=(1,2)) ≈ 0.036175710594315236
+  @test Flux.tversky_loss(ŷ, y, dims=(1,2), β = 0.8) ≈ 0.06281407035175879
+  @test Flux.tversky_loss(y, y, dims=(1,2)) ≈ -0.6904761904761902
 end

 @testset "no spurious promotions" begin
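As a sanity check on the dice value above: assuming the usual dice-coefficient loss with a smoothing term of 1, i.e. `1 - (2*sum(y .* ŷ) + 1) / (sum(ŷ.^2) + sum(y.^2) + 1)` (the exact reduction Flux applies per `dims` is an assumption here), the expected number reproduces by hand:

    y = [1.0 0.5 0.3 2.4]
    ŷ = [0 1.4 0.5 1.2]
    num = 2 * sum(y .* ŷ) + 1              # 2 * 3.73 + 1 = 8.46
    den = sum(ŷ .^ 2) + sum(y .^ 2) + 1    # 3.65 + 7.10 + 1 = 11.75
    1 - num / den                          # ≈ 0.28, matching 0.2799999999999999 above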