Added tests for new loss functions
This commit is contained in:
parent
643086c8db
commit
44a977b7a4
@ -13,6 +13,20 @@ const ϵ = 1e-7
|
||||
@test mse(ŷ, y) ≈ (.1^2 + .9^2)/2
|
||||
end
|
||||
|
||||
@testset "mae" begin
|
||||
@test Flux.mae(ŷ, y) ≈ 1/2
|
||||
end
|
||||
|
||||
@testset "huber_loss" begin
|
||||
@test Flux.huber_loss(ŷ, y) ≈ 0.0012499999999999994
|
||||
end
|
||||
|
||||
y = [123,456,789]
|
||||
y1 = [345,332,789]
|
||||
@testset "msle" begin
|
||||
@test Flux.msle(y1, y) ≈ 0.38813985859136585
|
||||
end
|
||||
|
||||
# Now onehot y's
|
||||
y = onehotbatch([1, 1, 0, 0], 0:1)
|
||||
ŷ = [.1 .9; .9 .1; .9 .1; .1 .9]'
|
||||
@ -64,6 +78,11 @@ const ϵ = 1e-7
|
||||
@test Flux.hinge(y, 0.5 .* y) ≈ 0.125
|
||||
end
|
||||
|
||||
@testset "squared_hinge" begin
|
||||
@test Flux.squared_hinge(y, y1) ≈ 0
|
||||
@test Flux.squared_hinge(y, 0.5 .* y) ≈ 0.0625
|
||||
end
|
||||
|
||||
y = [0.1 0.2 0.3]
|
||||
y1 = [0.4 0.5 0.6]
|
||||
@testset "poisson" begin
|
||||
@ -75,7 +94,7 @@ const ϵ = 1e-7
|
||||
for T in (Float32, Float64)
|
||||
y = rand(T, 2)
|
||||
ŷ = rand(T, 2)
|
||||
for f in (mse, crossentropy, logitcrossentropy, Flux.kldivergence, Flux.hinge, Flux.poisson)
|
||||
for f in (mse, crossentropy, logitcrossentropy, Flux.kldivergence, Flux.hinge, Flux.poisson,Flux.mae,Flux.huber_loss,Flux.msle,Flux.squared_hinge)
|
||||
fwd, back = Flux.pullback(f, ŷ, y)
|
||||
@test fwd isa T
|
||||
@test eltype(back(one(T))[1]) == T
|
||||
|
Loading…
Reference in New Issue
Block a user