Added tests for dice and Tversky loss

This commit is contained in:
Adarsh Kumar 2020-02-27 02:29:39 +05:30 committed by GitHub
parent 980ce72914
commit 3d8965230f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 14 additions and 1 deletions

View File

@ -89,12 +89,25 @@ const ϵ = 1e-7
@test Flux.poisson(y, y1) 1.0160455586700767 @test Flux.poisson(y, y1) 1.0160455586700767
@test Flux.poisson(y, y) 0.5044459776946685 @test Flux.poisson(y, y) 0.5044459776946685
end end
y = [1.0 0.5 0.3 2.4]
y1 = [0 1.4 0.5 1.2]
@testset "dice_coeff_loss" begin
@test Flux.dice_coeff_loss(y, y1) 0.2799999999999999
@test Flux.dice_coeff_loss(y,y) 0.0
end
@testset "tversky_loss" begin
@test Flux.tversky_loss(y,y1) 0.028747433264887046
@test Flux.tversky_loss(y,y1,0.8) 0.050200803212851364
@test Flux.tversky_loss(y,y) -0.5576923076923075
end
@testset "no spurious promotions" begin @testset "no spurious promotions" begin
for T in (Float32, Float64) for T in (Float32, Float64)
y = rand(T, 2) y = rand(T, 2)
ŷ = rand(T, 2) ŷ = rand(T, 2)
for f in (mse, crossentropy, logitcrossentropy, Flux.kldivergence, Flux.hinge, Flux.poisson,Flux.mae,Flux.huber_loss,Flux.msle,Flux.squared_hinge) for f in (mse, crossentropy, logitcrossentropy, Flux.kldivergence, Flux.hinge, Flux.poisson,Flux.mae,Flux.huber_loss,Flux.msle,Flux.squared_hinge,Flux.dice_coeff_loss,Flux.tversky_loss)
fwd, back = Flux.pullback(f, , y) fwd, back = Flux.pullback(f, , y)
@test fwd isa T @test fwd isa T
@test eltype(back(one(T))[1]) == T @test eltype(back(one(T))[1]) == T