fix MethodError for == and ≈
```julia param([2]).^2 == [4.0] ERROR: MethodError: ==(::TrackedArray{…,Array{Float64,1}}, ::Array{Float64,1}) is ambiguous. Candidates: ==(x::TrackedArray, y) in Main.Flux.Tracker at /Users/jc/.julia/dev/Flux/src/tracker/array.jl:63 ==(A::AbstractArray, B::AbstractArray) in Base at abstractarray.jl:1686 Possible fix, define ==(::TrackedArray, ::AbstractArray) ```
This commit is contained in:
parent
7bfe431321
commit
4ac76c35b0
@ -1,4 +1,4 @@
|
||||
import Base: *, ==, ≈
|
||||
import Base: *
|
||||
|
||||
import LinearAlgebra
|
||||
using Statistics
|
||||
@ -60,13 +60,11 @@ Base.similar(x::TrackedArray, dims::Union{AbstractUnitRange,Integer}...) =
|
||||
|
||||
Base.similar(x::TrackedArray, T::Type) = similar(data(x), T)
|
||||
|
||||
x::TrackedArray == y = data(x) == y
|
||||
y == x::TrackedArray = y == data(x)
|
||||
x::TrackedArray == y::TrackedArray = data(x) == data(y)
|
||||
|
||||
x::TrackedArray ≈ y = data(x) ≈ y
|
||||
y ≈ x::TrackedArray = y ≈ data(x)
|
||||
x::TrackedArray ≈ y::TrackedArray = data(x) ≈ data(y)
|
||||
for op in [:(==), :≈]
|
||||
@eval Base.$op(x::TrackedArray, y::AbstractArray) = Base.$op(data(x), y)
|
||||
@eval Base.$op(x::AbstractArray, y::TrackedArray) = Base.$op(x, data(y))
|
||||
@eval Base.$op(x::TrackedArray, y::TrackedArray) = Base.$op(data(x), data(y))
|
||||
end
|
||||
|
||||
# Array Stdlib
|
||||
|
||||
|
@ -30,9 +30,11 @@ Base.convert(::Type{TrackedReal{T}}, x::Real) where T = TrackedReal(convert(T, x
|
||||
Base.convert(::Type{TrackedReal{T}}, x::TrackedReal{S}) where {T,S} =
|
||||
error("Not implemented: convert tracked $S to tracked $T")
|
||||
|
||||
Base.:(<)(x::TrackedReal, y::TrackedReal) = data(x) < data(y)
|
||||
Base.:(==)(x::TrackedReal, y::TrackedReal) = data(x) == data(y)
|
||||
Base.:(≈)(x::TrackedReal, y::TrackedReal) = data(x) ≈ data(y)
|
||||
for op in [:(==), :≈, :<]
|
||||
@eval Base.$op(x::TrackedReal, y::Number) = Base.$op(data(x), y)
|
||||
@eval Base.$op(x::Number, y::TrackedReal) = Base.$op(x, data(y))
|
||||
@eval Base.$op(x::TrackedReal, y::TrackedReal) = Base.$op(data(x), data(y))
|
||||
end
|
||||
|
||||
Base.eps(x::TrackedReal) = eps(data(x))
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user