2018-07-18 13:39:20 +00:00
|
|
|
|
using Test
|
2018-06-26 17:43:16 +00:00
|
|
|
|
using Flux: onehotbatch, mse, crossentropy, logitcrossentropy,
|
2018-02-06 11:32:46 +00:00
|
|
|
|
σ, binarycrossentropy, logitbinarycrossentropy
|
2017-12-05 23:38:15 +00:00
|
|
|
|
|
2018-06-26 17:43:16 +00:00
|
|
|
|
# Small numerical tolerance kept at file scope (10⁻⁷).
# NOTE(review): the tests in this file pass a literal `ϵ=0` keyword and do not
# appear to read this constant — confirm it is still needed.
const ϵ = 1.0e-7
|
|
|
|
|
|
2017-12-05 23:38:15 +00:00
|
|
|
|
# Tests for the loss functions imported at the top of the file:
# regression (mse), multi-class (crossentropy / logitcrossentropy, plain and
# class-weighted), binary (binarycrossentropy / logitbinarycrossentropy), and a
# check that values and gradients keep the input element type.
@testset "losses" begin
    # ---- Regression-style targets -----------------------------------------
    target = [1, 1, 0, 0]
    pred = [.9, .1, .1, .9]

    @testset "mse" begin
        # Absolute errors are .1, .9, .1, .9, so the mean square is (.1^2 + .9^2)/2.
        @test mse(pred, target) ≈ (.1^2 + .9^2)/2
    end

    # ---- One-hot targets ---------------------------------------------------
    # Classes 0:1 form the rows; the four samples form the columns.
    target = onehotbatch([1, 1, 0, 0], 0:1)
    pred = [.1 .9; .9 .1; .9 .1; .1 .9]'
    # Each logit column [log(.1/.9), 0] (or its reverse) has softmax [.1, .9],
    # so these logits describe the same distribution as `pred`.
    logratio = log(.1 / .9)
    logits = [logratio 0.0; 0.0 logratio; 0.0 logratio; logratio 0.0]'
    # Two samples score .9 on the true class and two score .1:
    # -(log(.9) + log(.1)) / 2.
    expected = 1.203972804325936

    @testset "crossentropy" begin
        @test crossentropy(pred, target) ≈ expected
    end

    @testset "logitcrossentropy" begin
        @test logitcrossentropy(logits, target) ≈ expected
    end

    # Class weights: all-ones leaves the loss unchanged, all-.5 halves it, and
    # [2, .5] rescales the class-0 and class-1 terms independently.
    @testset "weighted_crossentropy" begin
        @test crossentropy(pred, target, weight = ones(2)) ≈ expected
        @test crossentropy(pred, target, weight = [.5, .5]) ≈ expected/2
        @test crossentropy(pred, target, weight = [2, .5]) ≈ 1.5049660054074199
    end

    @testset "weighted_logitcrossentropy" begin
        @test logitcrossentropy(logits, target, weight = ones(2)) ≈ expected
        @test logitcrossentropy(logits, target, weight = [.5, .5]) ≈ expected/2
        @test logitcrossentropy(logits, target, weight = [2, .5]) ≈ 1.5049660054074199
    end

    # ---- Binary targets with random logits --------------------------------
    logits, target = randn(3), rand(3)
    probs = σ.(logits)

    @testset "binarycrossentropy" begin
        # With ϵ=0 the result must match the textbook binary cross-entropy.
        @test binarycrossentropy.(probs, target; ϵ=0) ≈
            -target.*log.(probs) - (1 .- target).*log.(1 .- probs)
        # With the default keyword, the reference adds eps(p) inside each log.
        @test binarycrossentropy.(probs, target) ≈
            -target.*log.(probs .+ eps.(probs)) - (1 .- target).*log.(1 .- probs .+ eps.(probs))
    end

    @testset "logitbinarycrossentropy" begin
        # Feeding raw logits must agree with applying σ first (no ϵ offset).
        @test logitbinarycrossentropy.(logits, target) ≈ binarycrossentropy.(probs, target; ϵ=0)
    end

    # ---- Element-type preservation ----------------------------------------
    # The loss value and its gradient must stay in the input float type T.
    @testset "no spurious promotions" begin
        for T in (Float32, Float64)
            target = rand(T, 2)
            pred = rand(T, 2)
            for loss in (mse, crossentropy, logitcrossentropy)
                val, pb = Flux.forward(loss, pred, target)
                @test val isa T
                @test eltype(pb(one(T))[1]) == T
            end
        end
    end
end
|