fix MethodError for == and ≈

```julia
param([2]).^2 == [4.0]
ERROR: MethodError: ==(::TrackedArray{…,Array{Float64,1}}, ::Array{Float64,1}) is ambiguous. Candidates:
  ==(x::TrackedArray, y) in Main.Flux.Tracker at /Users/jc/.julia/dev/Flux/src/tracker/array.jl:63
  ==(A::AbstractArray, B::AbstractArray) in Base at abstractarray.jl:1686
Possible fix, define
  ==(::TrackedArray, ::AbstractArray)
```
This commit is contained in:
Johnny Chen 2018-08-25 14:51:40 +08:00
parent 7bfe431321
commit 4ac76c35b0
2 changed files with 11 additions and 11 deletions

View File

@ -1,4 +1,4 @@
import Base: *, ==,
import Base: *
import LinearAlgebra
using Statistics
@ -60,13 +60,11 @@ Base.similar(x::TrackedArray, dims::Union{AbstractUnitRange,Integer}...) =
Base.similar(x::TrackedArray, T::Type) = similar(data(x), T)
x::TrackedArray == y = data(x) == y
y == x::TrackedArray = y == data(x)
x::TrackedArray == y::TrackedArray = data(x) == data(y)
x::TrackedArray y = data(x) y
y x::TrackedArray = y data(x)
x::TrackedArray y::TrackedArray = data(x) data(y)
for op in [:(==), :≈]
@eval Base.$op(x::TrackedArray, y::AbstractArray) = Base.$op(data(x), y)
@eval Base.$op(x::AbstractArray, y::TrackedArray) = Base.$op(x, data(y))
@eval Base.$op(x::TrackedArray, y::TrackedArray) = Base.$op(data(x), data(y))
end
# Array Stdlib

View File

@ -30,9 +30,11 @@ Base.convert(::Type{TrackedReal{T}}, x::Real) where T = TrackedReal(convert(T, x
Base.convert(::Type{TrackedReal{T}}, x::TrackedReal{S}) where {T,S} =
error("Not implemented: convert tracked $S to tracked $T")
Base.:(<)(x::TrackedReal, y::TrackedReal) = data(x) < data(y)
Base.:(==)(x::TrackedReal, y::TrackedReal) = data(x) == data(y)
Base.:()(x::TrackedReal, y::TrackedReal) = data(x) data(y)
for op in [:(==), :≈, :<]
@eval Base.$op(x::TrackedReal, y::Number) = Base.$op(data(x), y)
@eval Base.$op(x::Number, y::TrackedReal) = Base.$op(x, data(y))
@eval Base.$op(x::TrackedReal, y::TrackedReal) = Base.$op(data(x), data(y))
end
Base.eps(x::TrackedReal) = eps(data(x))