Overload Base.eps() for TrackedReal

This commit is contained in:
Matthew Kelley 2018-06-26 23:55:43 -06:00
parent 0e95be3326
commit 864d72eef5
2 changed files with 3 additions and 2 deletions

View File

@ -25,8 +25,7 @@ Return `-y*log(ŷ + ϵ) - (1-y)*log(1-ŷ + ϵ)`. The ϵ term provides numerica
0.352317
0.86167
"""
# binarycrossentropy(ŷ, y; ϵ=eps(Flux.Tracker.data(ŷ))) = -y*log(ŷ + ϵ) - (1 - y)*log(1 - ŷ + ϵ)
binarycrossentropy(, y; ϵ=eps(Flux.Tracker.data())) = -y*log( + ϵ) - (1 - y)*log(1 - + ϵ)
binarycrossentropy(, y; ϵ=eps()) = -y*log( + ϵ) - (1 - y)*log(1 - + ϵ)
"""
logitbinarycrossentropy(logŷ, y)

View File

@ -31,6 +31,8 @@ Base.convert(::Type{TrackedReal{T}}, x::TrackedReal{S}) where {T,S} =
Base.:(<)(x::TrackedReal, y::TrackedReal) = data(x) < data(y)
Base.:(==)(x::TrackedReal, y::TrackedReal) = data(x) == data(y)
Base.eps(x::TrackedReal) = eps(data(x))
for f in :[isinf, isnan, isfinite].args
@eval Base.$f(x::TrackedReal) = Base.$f(data(x))
end