diff --git a/src/tracker/array.jl b/src/tracker/array.jl index ce72755d..35d2c39f 100644 --- a/src/tracker/array.jl +++ b/src/tracker/array.jl @@ -1,4 +1,4 @@ -import Base: *, == +import Base: *, ==, ≈ import LinearAlgebra using Statistics @@ -64,6 +64,10 @@ x::TrackedArray == y = data(x) == y y == x::TrackedArray = y == data(x) x::TrackedArray == y::TrackedArray = data(x) == data(y) +x::TrackedArray ≈ y = data(x) ≈ y +y ≈ x::TrackedArray = y ≈ data(x) +x::TrackedArray ≈ y::TrackedArray = data(x) ≈ data(y) + # Array Stdlib Base.getindex(xs::TrackedArray, i...) = track(getindex, xs, i...) diff --git a/src/tracker/scalar.jl b/src/tracker/scalar.jl index 9ff1895a..03892c46 100644 --- a/src/tracker/scalar.jl +++ b/src/tracker/scalar.jl @@ -32,6 +32,7 @@ Base.convert(::Type{TrackedReal{T}}, x::TrackedReal{S}) where {T,S} = Base.:(<)(x::TrackedReal, y::TrackedReal) = data(x) < data(y) Base.:(==)(x::TrackedReal, y::TrackedReal) = data(x) == data(y) +Base.:(≈)(x::TrackedReal, y::TrackedReal) = data(x) ≈ data(y) Base.eps(x::TrackedReal) = eps(data(x))