diff --git a/src/layers/stateless.jl b/src/layers/stateless.jl index ccd4fe4c..ba80e8a6 100644 --- a/src/layers/stateless.jl +++ b/src/layers/stateless.jl @@ -15,9 +15,9 @@ function logitcrossentropy(logŷ::AbstractVecOrMat, y::AbstractVecOrMat; weight end """ - binarycrossentropy(ŷ, y) + binarycrossentropy(ŷ, y; ϵ=eps(ŷ)) -Return `-y*log(ŷ) - (1-y)*log(1-ŷ)`. +Return `-y*log(ŷ + ϵ) - (1-y)*log(1-ŷ + ϵ)`. The ϵ term provides numerical stability. julia> binarycrossentropy.(σ.([-1.1491, 0.8619, 0.3127]), [1, 1, 0.]) 3-element Array{Float64,1}: @@ -25,7 +25,7 @@ Return `-y*log(ŷ) - (1-y)*log(1-ŷ)`. 0.352317 0.86167 """ -binarycrossentropy(ŷ, y) = -y*log(ŷ) - (1 - y)*log(1 - ŷ) +binarycrossentropy(ŷ, y; ϵ=eps(ŷ)) = -y*log(ŷ + ϵ) - (1 - y)*log(1 - ŷ + ϵ) """ logitbinarycrossentropy(logŷ, y) diff --git a/src/tracker/scalar.jl b/src/tracker/scalar.jl index f94f1647..773943c0 100644 --- a/src/tracker/scalar.jl +++ b/src/tracker/scalar.jl @@ -31,6 +31,8 @@ Base.convert(::Type{TrackedReal{T}}, x::TrackedReal{S}) where {T,S} = Base.:(<)(x::TrackedReal, y::TrackedReal) = data(x) < data(y) Base.:(==)(x::TrackedReal, y::TrackedReal) = data(x) == data(y) +Base.eps(x::TrackedReal) = eps(data(x)) + for f in :[isinf, isnan, isfinite].args @eval Base.$f(x::TrackedReal) = Base.$f(data(x)) end diff --git a/test/layers/stateless.jl b/test/layers/stateless.jl index ecfa7014..31a67aa7 100644 --- a/test/layers/stateless.jl +++ b/test/layers/stateless.jl @@ -1,7 +1,9 @@ using Base.Test -using Flux: onehotbatch, mse, crossentropy, logitcrossentropy, +using Flux: onehotbatch, mse, crossentropy, logitcrossentropy, σ, binarycrossentropy, logitbinarycrossentropy +const ϵ = 1e-7 + @testset "losses" begin # First, regression-style y's y = [1, 1, 0, 0] @@ -40,10 +42,11 @@ using Flux: onehotbatch, mse, crossentropy, logitcrossentropy, logŷ, y = randn(3), rand(3) @testset "binarycrossentropy" begin - @test binarycrossentropy.(σ.(logŷ), y) ≈ -y.*log.(σ.(logŷ)) - (1 - y).*log.(1 - σ.(logŷ)) + @test binarycrossentropy.(σ.(logŷ), y; ϵ=0) ≈ -y.*log.(σ.(logŷ)) - (1 - y).*log.(1 - σ.(logŷ)) + @test binarycrossentropy.(σ.(logŷ), y) ≈ -y.*log.(σ.(logŷ) .+ eps.(σ.(logŷ))) - (1 - y).*log.(1 - σ.(logŷ) .+ eps.(σ.(logŷ))) end - + @testset "logitbinarycrossentropy" begin - @test logitbinarycrossentropy.(logŷ, y) ≈ binarycrossentropy.(σ.(logŷ), y) + @test logitbinarycrossentropy.(logŷ, y) ≈ binarycrossentropy.(σ.(logŷ), y; ϵ=0) end end