Overload Base.eps() for TrackedReal
This commit is contained in:
parent
0e95be3326
commit
864d72eef5
@ -25,8 +25,7 @@ Return `-y*log(ŷ + ϵ) - (1-y)*log(1-ŷ + ϵ)`. The ϵ term provides numerica
|
|||||||
0.352317
|
0.352317
|
||||||
0.86167
|
0.86167
|
||||||
"""
|
"""
|
||||||
# binarycrossentropy(ŷ, y; ϵ=eps(Flux.Tracker.data(ŷ))) = -y*log(ŷ + ϵ) - (1 - y)*log(1 - ŷ + ϵ)
|
binarycrossentropy(ŷ, y; ϵ=eps(ŷ)) = -y*log(ŷ + ϵ) - (1 - y)*log(1 - ŷ + ϵ)
|
||||||
binarycrossentropy(ŷ, y; ϵ=eps(Flux.Tracker.data(ŷ))) = -y*log(ŷ + ϵ) - (1 - y)*log(1 - ŷ + ϵ)
|
|
||||||
|
|
||||||
"""
|
"""
|
||||||
logitbinarycrossentropy(logŷ, y)
|
logitbinarycrossentropy(logŷ, y)
|
||||||
|
@ -31,6 +31,8 @@ Base.convert(::Type{TrackedReal{T}}, x::TrackedReal{S}) where {T,S} =
|
|||||||
Base.:(<)(x::TrackedReal, y::TrackedReal) = data(x) < data(y)
|
Base.:(<)(x::TrackedReal, y::TrackedReal) = data(x) < data(y)
|
||||||
Base.:(==)(x::TrackedReal, y::TrackedReal) = data(x) == data(y)
|
Base.:(==)(x::TrackedReal, y::TrackedReal) = data(x) == data(y)
|
||||||
|
|
||||||
|
Base.eps(x::TrackedReal) = eps(data(x))
|
||||||
|
|
||||||
for f in :[isinf, isnan, isfinite].args
|
for f in :[isinf, isnan, isfinite].args
|
||||||
@eval Base.$f(x::TrackedReal) = Base.$f(data(x))
|
@eval Base.$f(x::TrackedReal) = Base.$f(data(x))
|
||||||
end
|
end
|
||||||
|
Loading…
Reference in New Issue
Block a user