Added epsilon term to binarycrossentropy

This commit is contained in:
Matthew Kelley 2018-06-26 11:43:16 -06:00
parent bed6d2311e
commit e08fd7a6d2
2 changed files with 10 additions and 7 deletions

View File

@ -15,9 +15,9 @@ function logitcrossentropy(logŷ::AbstractVecOrMat, y::AbstractVecOrMat; weight
end end
""" """
binarycrossentropy(, y) binarycrossentropy(, y; ϵ)
Return `-y*log(ŷ) - (1-y)*log(1-ŷ)`. Return `-y*log(ŷ + ϵ) - (1-y)*log(1-ŷ + ϵ)`. The ϵ term provides numerical stability.
julia> binarycrossentropy.(σ.([-1.1491, 0.8619, 0.3127]), [1, 1, 0.]) julia> binarycrossentropy.(σ.([-1.1491, 0.8619, 0.3127]), [1, 1, 0.])
3-element Array{Float64,1}: 3-element Array{Float64,1}:
@ -25,7 +25,7 @@ Return `-y*log(ŷ) - (1-y)*log(1-ŷ)`.
0.352317 0.352317
0.86167 0.86167
""" """
binarycrossentropy(, y) = -y*log() - (1 - y)*log(1 - ) binarycrossentropy(, y; ϵ=1e-7) = -y*log( + ϵ) - (1 - y)*log(1 - + ϵ)
""" """
logitbinarycrossentropy(logŷ, y) logitbinarycrossentropy(logŷ, y)

View File

@ -2,6 +2,8 @@ using Base.Test
using Flux: onehotbatch, mse, crossentropy, logitcrossentropy, using Flux: onehotbatch, mse, crossentropy, logitcrossentropy,
σ, binarycrossentropy, logitbinarycrossentropy σ, binarycrossentropy, logitbinarycrossentropy
const ϵ = 1e-7
@testset "losses" begin @testset "losses" begin
# First, regression-style y's # First, regression-style y's
y = [1, 1, 0, 0] y = [1, 1, 0, 0]
@ -40,10 +42,11 @@ using Flux: onehotbatch, mse, crossentropy, logitcrossentropy,
logŷ, y = randn(3), rand(3) logŷ, y = randn(3), rand(3)
@testset "binarycrossentropy" begin @testset "binarycrossentropy" begin
@test binarycrossentropy.(σ.(logŷ), y) -y.*log.(σ.(logŷ)) - (1 - y).*log.(1 - σ.(logŷ)) @test binarycrossentropy.(σ.(logŷ), y; ϵ=0) -y.*log.(σ.(logŷ)) - (1 - y).*log.(1 - σ.(logŷ))
@test binarycrossentropy.(σ.(logŷ), y) -y.*log.(σ.(logŷ) + 1e-7) - (1 - y).*log.(1 - σ.(logŷ) + 1e-7)
end end
@testset "logitbinarycrossentropy" begin @testset "logitbinarycrossentropy" begin
@test logitbinarycrossentropy.(logŷ, y) binarycrossentropy.(σ.(logŷ), y) @test logitbinarycrossentropy.(logŷ, y) binarycrossentropy.(σ.(logŷ), y; ϵ=0)
end end
end end