diff --git a/src/tracker/array.jl b/src/tracker/array.jl index 35d2c39f..923b925c 100644 --- a/src/tracker/array.jl +++ b/src/tracker/array.jl @@ -1,4 +1,4 @@ -import Base: *, ==, ≈ +import Base: * import LinearAlgebra using Statistics @@ -60,13 +60,11 @@ Base.similar(x::TrackedArray, dims::Union{AbstractUnitRange,Integer}...) = Base.similar(x::TrackedArray, T::Type) = similar(data(x), T) -x::TrackedArray == y = data(x) == y -y == x::TrackedArray = y == data(x) -x::TrackedArray == y::TrackedArray = data(x) == data(y) - -x::TrackedArray ≈ y = data(x) ≈ y -y ≈ x::TrackedArray = y ≈ data(x) -x::TrackedArray ≈ y::TrackedArray = data(x) ≈ data(y) +for op in [:(==), :≈] + @eval Base.$op(x::TrackedArray, y::AbstractArray) = Base.$op(data(x), y) + @eval Base.$op(x::AbstractArray, y::TrackedArray) = Base.$op(x, data(y)) + @eval Base.$op(x::TrackedArray, y::TrackedArray) = Base.$op(data(x), data(y)) +end # Array Stdlib diff --git a/src/tracker/scalar.jl b/src/tracker/scalar.jl index 03892c46..9e987333 100644 --- a/src/tracker/scalar.jl +++ b/src/tracker/scalar.jl @@ -30,9 +30,11 @@ Base.convert(::Type{TrackedReal{T}}, x::Real) where T = TrackedReal(convert(T, x Base.convert(::Type{TrackedReal{T}}, x::TrackedReal{S}) where {T,S} = error("Not implemented: convert tracked $S to tracked $T") -Base.:(<)(x::TrackedReal, y::TrackedReal) = data(x) < data(y) -Base.:(==)(x::TrackedReal, y::TrackedReal) = data(x) == data(y) -Base.:(≈)(x::TrackedReal, y::TrackedReal) = data(x) ≈ data(y) +for op in [:(==), :≈, :<] + @eval Base.$op(x::TrackedReal, y::Number) = Base.$op(data(x), y) + @eval Base.$op(x::Number, y::TrackedReal) = Base.$op(x, data(y)) + @eval Base.$op(x::TrackedReal, y::TrackedReal) = Base.$op(data(x), data(y)) +end Base.eps(x::TrackedReal) = eps(data(x))