Test argument consistency with ŷ and y

This commit is contained in:
Adarsh Kumar 2020-03-02 20:33:12 +05:30 committed by GitHub
parent 2f05094068
commit 92e09e204d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 17 additions and 16 deletions

View File

@ -22,9 +22,9 @@ const ϵ = 1e-7
end end
y = [123.0,456.0,789.0] y = [123.0,456.0,789.0]
y1 = [345.0,332.0,789.0] ŷ = [345.0,332.0,789.0]
@testset "msle" begin @testset "msle" begin
@test Flux.msle(y1, y) 0.38813985859136585 @test Flux.msle(ŷ, y) 0.38813985859136585
end end
# Now onehot y's # Now onehot y's
@ -65,49 +65,50 @@ const ϵ = 1e-7
end end
y = [1 2 3] y = [1 2 3]
y1 = [4.0 5.0 6.0] ŷ = [4.0 5.0 6.0]
@testset "kldivergence" begin @testset "kldivergence" begin
@test Flux.kldivergence(y, y1) 4.761838062403337 @test Flux.kldivergence(ŷ, y) -1.7661057888493457
@test Flux.kldivergence(y, y) 0 @test Flux.kldivergence(y, y) 0
end end
y = [1 2 3 4] y = [1 2 3 4]
y1 = [5.0 6.0 7.0 8.0] ŷ = [5.0 6.0 7.0 8.0]
@testset "hinge" begin @testset "hinge" begin
@test Flux.hinge(y, y1) 0 @test Flux.hinge(ŷ, y) 0
@test Flux.hinge(y, 0.5 .* y) 0.125 @test Flux.hinge(y, 0.5 .* y) 0.125
end end
@testset "squared_hinge" begin @testset "squared_hinge" begin
@test Flux.squared_hinge(y, y1) 0 @test Flux.squared_hinge(ŷ, y) 0
@test Flux.squared_hinge(y, 0.5 .* y) 0.0625 @test Flux.squared_hinge(y, 0.5 .* y) 0.0625
end end
y = [0.1 0.2 0.3] y = [0.1 0.2 0.3]
y1 = [0.4 0.5 0.6] ŷ = [0.4 0.5 0.6]
@testset "poisson" begin @testset "poisson" begin
@test Flux.poisson(y, y1) 1.0160455586700767 @test Flux.poisson(ŷ, y) 0.6278353988097339
@test Flux.poisson(y, y) 0.5044459776946685 @test Flux.poisson(y, y) 0.5044459776946685
end end
y = [1.0 0.5 0.3 2.4] y = [1.0 0.5 0.3 2.4]
y1 = [0 1.4 0.5 1.2] ŷ = [0 1.4 0.5 1.2]
@testset "dice_coeff_loss" begin @testset "dice_coeff_loss" begin
@test Flux.dice_coeff_loss(y, y1) 0.2799999999999999 @test Flux.dice_coeff_loss(ŷ, y) 0.2799999999999999
@test Flux.dice_coeff_loss(y,y) 0.0 @test Flux.dice_coeff_loss(y, y) 0.0
end end
@testset "tversky_loss" begin @testset "tversky_loss" begin
@test Flux.tversky_loss(y,y1) 0.028747433264887046 @test Flux.tversky_loss(ŷ, y) -0.06772009029345383
@test Flux.tversky_loss(y,y1,beta = 0.8) 0.050200803212851364 @test Flux.tversky_loss(ŷ, y, β = 0.8) -0.09490740740740744
@test Flux.tversky_loss(y,y) -0.5576923076923075 @test Flux.tversky_loss(y, y) -0.5576923076923075
end end
@testset "no spurious promotions" begin @testset "no spurious promotions" begin
for T in (Float32, Float64) for T in (Float32, Float64)
y = rand(T, 2) y = rand(T, 2)
ŷ = rand(T, 2) ŷ = rand(T, 2)
for f in (mse, crossentropy, logitcrossentropy, Flux.kldivergence, Flux.hinge, Flux.poisson,Flux.mae,Flux.huber_loss,Flux.msle,Flux.squared_hinge,Flux.dice_coeff_loss,Flux.tversky_loss) for f in (mse, crossentropy, logitcrossentropy, Flux.kldivergence, Flux.hinge, Flux.poisson,
Flux.mae, Flux.huber_loss, Flux.msle, Flux.squared_hinge, Flux.dice_coeff_loss, Flux.tversky_loss)
fwd, back = Flux.pullback(f, , y) fwd, back = Flux.pullback(f, , y)
@test fwd isa T @test fwd isa T
@test eltype(back(one(T))[1]) == T @test eltype(back(one(T))[1]) == T