Merge pull request #306 from maetshju/pull-request/e08fd7a6
Add epsilon term to binarycrossentropy
This commit is contained in:
commit
d76e790818
@ -15,9 +15,9 @@ function logitcrossentropy(logŷ::AbstractVecOrMat, y::AbstractVecOrMat; weight
|
||||
end
|
||||
|
||||
"""
|
||||
binarycrossentropy(ŷ, y)
|
||||
binarycrossentropy(ŷ, y; ϵ=eps(ŷ))
|
||||
|
||||
Return `-y*log(ŷ) - (1-y)*log(1-ŷ)`.
|
||||
Return `-y*log(ŷ + ϵ) - (1-y)*log(1-ŷ + ϵ)`. The ϵ term provides numerical stability.
|
||||
|
||||
julia> binarycrossentropy.(σ.([-1.1491, 0.8619, 0.3127]), [1, 1, 0.])
|
||||
3-element Array{Float64,1}:
|
||||
@ -25,7 +25,7 @@ Return `-y*log(ŷ) - (1-y)*log(1-ŷ)`.
|
||||
0.352317
|
||||
0.86167
|
||||
"""
|
||||
function binarycrossentropy(ŷ, y)
    # Penalty for the positive class (weighted by `y`) minus the log-term of
    # the negative class (weighted by `1 - y`); no stabilising term is added.
    positive = -y * log(ŷ)
    negative = (1 - y) * log(1 - ŷ)
    return positive - negative
end
|
||||
function binarycrossentropy(ŷ, y; ϵ=eps(ŷ))
    # `ϵ` is added inside both logarithms; by default it is `eps(ŷ)`, so it
    # scales with the magnitude of the prediction itself.
    positive = -y * log(ŷ + ϵ)
    negative = (1 - y) * log(1 - ŷ + ϵ)
    return positive - negative
end
|
||||
|
||||
"""
|
||||
logitbinarycrossentropy(logŷ, y)
|
||||
|
@ -31,6 +31,8 @@ Base.convert(::Type{TrackedReal{T}}, x::TrackedReal{S}) where {T,S} =
|
||||
Base.:(<)(x::TrackedReal, y::TrackedReal) = data(x) < data(y)
|
||||
Base.:(==)(x::TrackedReal, y::TrackedReal) = data(x) == data(y)
|
||||
|
||||
function Base.eps(x::TrackedReal)
    # Delegate to `eps` of the value returned by `data(x)`.
    return eps(data(x))
end
|
||||
|
||||
# For each of `isinf`, `isnan` and `isfinite`, generate a `Base` method on
# `TrackedReal` that applies the same predicate to `data(x)`.
for f in :[isinf, isnan, isfinite].args
|
||||
@eval Base.$f(x::TrackedReal) = Base.$f(data(x))
|
||||
end
|
||||
|
@ -1,7 +1,9 @@
|
||||
using Base.Test
|
||||
using Flux: onehotbatch, mse, crossentropy, logitcrossentropy,
|
||||
using Flux: onehotbatch, mse, crossentropy, logitcrossentropy,
|
||||
σ, binarycrossentropy, logitbinarycrossentropy
|
||||
|
||||
const ϵ = 1e-7
|
||||
|
||||
@testset "losses" begin
|
||||
# First, regression-style y's
|
||||
y = [1, 1, 0, 0]
|
||||
@ -40,10 +42,11 @@ using Flux: onehotbatch, mse, crossentropy, logitcrossentropy,
|
||||
|
||||
logŷ, y = randn(3), rand(3)
|
||||
# Checks `binarycrossentropy` broadcast over `σ.(logŷ)` and `y` against the
# closed-form expression written out by hand.
@testset "binarycrossentropy" begin
|
||||
# Default call compared with the formula that has no ϵ term.
@test binarycrossentropy.(σ.(logŷ), y) ≈ -y.*log.(σ.(logŷ)) - (1 - y).*log.(1 - σ.(logŷ))
|
||||
# With ϵ=0 the result is compared with the formula that has no ϵ term.
@test binarycrossentropy.(σ.(logŷ), y; ϵ=0) ≈ -y.*log.(σ.(logŷ)) - (1 - y).*log.(1 - σ.(logŷ))
|
||||
# Default call compared with the formula that adds `eps` of each prediction inside both logs.
@test binarycrossentropy.(σ.(logŷ), y) ≈ -y.*log.(σ.(logŷ) .+ eps.(σ.(logŷ))) - (1 - y).*log.(1 - σ.(logŷ) .+ eps.(σ.(logŷ)))
|
||||
end
|
||||
|
||||
|
||||
# Checks that `logitbinarycrossentropy` on raw `logŷ` agrees with
# `binarycrossentropy` applied to `σ.(logŷ)`.
@testset "logitbinarycrossentropy" begin
|
||||
# Comparison against `binarycrossentropy` called with its default arguments.
@test logitbinarycrossentropy.(logŷ, y) ≈ binarycrossentropy.(σ.(logŷ), y)
|
||||
# Comparison against `binarycrossentropy` called with ϵ=0.
@test logitbinarycrossentropy.(logŷ, y) ≈ binarycrossentropy.(σ.(logŷ), y; ϵ=0)
|
||||
end
|
||||
end
|
||||
|
Loading…
Reference in New Issue
Block a user