Call Flux.Tracker.data() on ŷ for bce

This commit is contained in:
Matthew Kelley 2018-06-26 14:48:51 -06:00
parent ed032cdb1e
commit 0e95be3326
2 changed files with 3 additions and 2 deletions

View File

@ -25,7 +25,8 @@ Return `-y*log(ŷ + ϵ) - (1-y)*log(1-ŷ + ϵ)`. The ϵ term provides numerica
0.352317
0.86167
"""
binarycrossentropy(, y; ϵ=eps()) = -y*log( + ϵ) - (1 - y)*log(1 - + ϵ)
# binarycrossentropy(ŷ, y; ϵ=eps(Flux.Tracker.data(ŷ))) = -y*log(ŷ + ϵ) - (1 - y)*log(1 - ŷ + ϵ)
binarycrossentropy(, y; ϵ=eps(Flux.Tracker.data())) = -y*log( + ϵ) - (1 - y)*log(1 - + ϵ)
"""
logitbinarycrossentropy(logŷ, y)

View File

@ -43,7 +43,7 @@ const ϵ = 1e-7
logŷ, y = randn(3), rand(3)
@testset "binarycrossentropy" begin
@test binarycrossentropy.(σ.(logŷ), y; ϵ=0) -y.*log.(σ.(logŷ)) - (1 - y).*log.(1 - σ.(logŷ))
@test binarycrossentropy.(σ.(logŷ), y) -y.*log.(σ.(logŷ) .+ eps.()) - (1 - y).*log.(1 - σ.(logŷ) .+ eps.())
@test binarycrossentropy.(σ.(logŷ), y) -y.*log.(σ.(logŷ) .+ eps.(σ.(log))) - (1 - y).*log.(1 - σ.(logŷ) .+ eps.(σ.(log)))
end
@testset "logitbinarycrossentropy" begin