From 3d8965230fc45f687d943f614dacd154f6212f11 Mon Sep 17 00:00:00 2001 From: Adarsh Kumar <45385384+AdarshKumar712@users.noreply.github.com> Date: Thu, 27 Feb 2020 02:29:39 +0530 Subject: [PATCH] Added tests for dice and Tversky loss --- test/layers/stateless.jl | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/test/layers/stateless.jl b/test/layers/stateless.jl index d038bcda..b7d15634 100644 --- a/test/layers/stateless.jl +++ b/test/layers/stateless.jl @@ -89,12 +89,25 @@ const ϵ = 1e-7 @test Flux.poisson(y, y1) ≈ 1.0160455586700767 @test Flux.poisson(y, y) ≈ 0.5044459776946685 end + + y = [1.0 0.5 0.3 2.4] + y1 = [0 1.4 0.5 1.2] + @testset "dice_coeff_loss" begin + @test Flux.dice_coeff_loss(y, y1) ≈ 0.2799999999999999 + @test Flux.dice_coeff_loss(y,y) ≈ 0.0 + end + + @testset "tversky_loss" begin + @test Flux.tversky_loss(y,y1) ≈ 0.028747433264887046 + @test Flux.tversky_loss(y,y1,0.8) ≈ 0.050200803212851364 + @test Flux.tversky_loss(y,y) ≈ -0.5576923076923075 + end @testset "no spurious promotions" begin for T in (Float32, Float64) y = rand(T, 2) ŷ = rand(T, 2) - for f in (mse, crossentropy, logitcrossentropy, Flux.kldivergence, Flux.hinge, Flux.poisson,Flux.mae,Flux.huber_loss,Flux.msle,Flux.squared_hinge) + for f in (mse, crossentropy, logitcrossentropy, Flux.kldivergence, Flux.hinge, Flux.poisson,Flux.mae,Flux.huber_loss,Flux.msle,Flux.squared_hinge,Flux.dice_coeff_loss,Flux.tversky_loss) fwd, back = Flux.pullback(f, ŷ, y) @test fwd isa T @test eltype(back(one(T))[1]) == T