Added tests for dice and Tversky loss
This commit is contained in:
parent
980ce72914
commit
3d8965230f
|
@ -90,11 +90,24 @@ const ϵ = 1e-7
|
|||
@test Flux.poisson(y, y) ≈ 0.5044459776946685
|
||||
end
|
||||
|
||||
y = [1.0 0.5 0.3 2.4]
|
||||
y1 = [0 1.4 0.5 1.2]
|
||||
@testset "dice_coeff_loss" begin
|
||||
@test Flux.dice_coeff_loss(y, y1) ≈ 0.2799999999999999
|
||||
@test Flux.dice_coeff_loss(y,y) ≈ 0.0
|
||||
end
|
||||
|
||||
@testset "tversky_loss" begin
|
||||
@test Flux.tversky_loss(y,y1) ≈ 0.028747433264887046
|
||||
@test Flux.tversky_loss(y,y1,0.8) ≈ 0.050200803212851364
|
||||
@test Flux.tversky_loss(y,y) ≈ -0.5576923076923075
|
||||
end
|
||||
|
||||
@testset "no spurious promotions" begin
|
||||
for T in (Float32, Float64)
|
||||
y = rand(T, 2)
|
||||
ŷ = rand(T, 2)
|
||||
for f in (mse, crossentropy, logitcrossentropy, Flux.kldivergence, Flux.hinge, Flux.poisson,Flux.mae,Flux.huber_loss,Flux.msle,Flux.squared_hinge)
|
||||
for f in (mse, crossentropy, logitcrossentropy, Flux.kldivergence, Flux.hinge, Flux.poisson,Flux.mae,Flux.huber_loss,Flux.msle,Flux.squared_hinge,Flux.dice_coeff_loss,Flux.tversky_loss)
|
||||
fwd, back = Flux.pullback(f, ŷ, y)
|
||||
@test fwd isa T
|
||||
@test eltype(back(one(T))[1]) == T
|
||||
|
|
Loading…
Reference in New Issue