using Flux: onehotbatch, mse, crossentropy
@testset "losses" begin
    # Regression-style targets: plain numeric vectors.
    y = [1, 1, 0, 0]
    y_hat = [.9, .1, .1, .9]

    @testset "mse" begin
        # mean((y_hat .- y).^2): errors are ±.1 and ±.9, each appearing twice,
        # so the mean of squares is (.1^2 + .9^2)/2.
        @test mse(y_hat, y) ≈ (.1^2 + .9^2)/2
    end

    # Classification-style targets: one-hot columns over the classes 0:1.
    y = onehotbatch([1, 1, 0, 0], 0:1)
    # Predicted class probabilities, one column per sample (hence the adjoint).
    y_hat = [.1 .9; .9 .1; .9 .1; .1 .9]'
    # Precomputed reference log-loss for y_hat against y.
    y_logloss = 1.203972804325936

    @testset "crossentropy" begin
        @test crossentropy(y_hat, y) ≈ y_logloss
    end

    @testset "weighted_crossentropy" begin
        # Unit class weights must reproduce the unweighted loss exactly.
        @test crossentropy(y_hat, y, weight = ones(2)) ≈ y_logloss
        # Uniformly scaling the weights scales the loss linearly.
        @test crossentropy(y_hat, y, weight = [.5, .5]) ≈ y_logloss/2
        # Unequal class weights: precomputed reference value.
        @test crossentropy(y_hat, y, weight = [2, .5]) ≈ 1.5049660054074199
    end
end