Added tests for new loss functions

This commit is contained in:
Adarsh Kumar 2020-02-05 23:20:06 +05:30 committed by GitHub
parent 643086c8db
commit 44a977b7a4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 21 additions and 2 deletions

View File

@ -12,6 +12,20 @@ const ϵ = 1e-7
@testset "mse" begin
@test mse(ŷ, y) (.1^2 + .9^2)/2
end
@testset "mae" begin
@test Flux.mae(ŷ, y) 1/2
end
@testset "huber_loss" begin
@test Flux.huber_loss(ŷ, y) 0.0012499999999999994
end
y = [123,456,789]
y1 = [345,332,789]
@testset "msle" begin
@test Flux.msle(y1, y) 0.38813985859136585
end
# Now onehot y's
y = onehotbatch([1, 1, 0, 0], 0:1)
@ -64,18 +78,23 @@ const ϵ = 1e-7
@test Flux.hinge(y, 0.5 .* y) 0.125
end
@testset "squared_hinge" begin
@test Flux.squared_hinge(y, y1) 0
@test Flux.squared_hinge(y, 0.5 .* y) 0.0625
end
y = [0.1 0.2 0.3]
y1 = [0.4 0.5 0.6]
@testset "poisson" begin
@test Flux.poisson(y, y1) 1.0160455586700767
@test Flux.poisson(y, y) 0.5044459776946685
end
@testset "no spurious promotions" begin
for T in (Float32, Float64)
y = rand(T, 2)
ŷ = rand(T, 2)
for f in (mse, crossentropy, logitcrossentropy, Flux.kldivergence, Flux.hinge, Flux.poisson)
for f in (mse, crossentropy, logitcrossentropy, Flux.kldivergence, Flux.hinge, Flux.poisson,Flux.mae,Flux.huber_loss,Flux.msle,Flux.squared_hinge)
fwd, back = Flux.pullback(f, , y)
@test fwd isa T
@test eltype(back(one(T))[1]) == T