From 4ac76c35b0ede5d9c7dc1134f732190543eb499f Mon Sep 17 00:00:00 2001 From: Johnny Chen Date: Sat, 25 Aug 2018 14:51:40 +0800 Subject: [PATCH] =?UTF-8?q?fix=20MethodError=20for=20=3D=3D=20and=20?= =?UTF-8?q?=E2=89=88?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ```julia param([2]).^2 == [4.0] ERROR: MethodError: ==(::TrackedArray{…,Array{Float64,1}}, ::Array{Float64,1}) is ambiguous. Candidates: ==(x::TrackedArray, y) in Main.Flux.Tracker at /Users/jc/.julia/dev/Flux/src/tracker/array.jl:63 ==(A::AbstractArray, B::AbstractArray) in Base at abstractarray.jl:1686 Possible fix, define ==(::TrackedArray, ::AbstractArray) ``` --- src/tracker/array.jl | 14 ++++++-------- src/tracker/scalar.jl | 8 +++++--- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/src/tracker/array.jl b/src/tracker/array.jl index 35d2c39f..923b925c 100644 --- a/src/tracker/array.jl +++ b/src/tracker/array.jl @@ -1,4 +1,4 @@ -import Base: *, ==, ≈ +import Base: * import LinearAlgebra using Statistics @@ -60,13 +60,11 @@ Base.similar(x::TrackedArray, dims::Union{AbstractUnitRange,Integer}...) = Base.similar(x::TrackedArray, T::Type) = similar(data(x), T) -x::TrackedArray == y = data(x) == y -y == x::TrackedArray = y == data(x) -x::TrackedArray == y::TrackedArray = data(x) == data(y) - -x::TrackedArray ≈ y = data(x) ≈ y -y ≈ x::TrackedArray = y ≈ data(x) -x::TrackedArray ≈ y::TrackedArray = data(x) ≈ data(y) +for op in [:(==), :≈] + @eval Base.$op(x::TrackedArray, y::AbstractArray) = Base.$op(data(x), y) + @eval Base.$op(x::AbstractArray, y::TrackedArray) = Base.$op(x, data(y)) + @eval Base.$op(x::TrackedArray, y::TrackedArray) = Base.$op(data(x), data(y)) +end # Array Stdlib diff --git a/src/tracker/scalar.jl b/src/tracker/scalar.jl index 03892c46..9e987333 100644 --- a/src/tracker/scalar.jl +++ b/src/tracker/scalar.jl @@ -30,9 +30,11 @@ Base.convert(::Type{TrackedReal{T}}, x::Real) where T = TrackedReal(convert(T, x Base.convert(::Type{TrackedReal{T}}, x::TrackedReal{S}) where {T,S} = error("Not implemented: convert tracked $S to tracked $T") -Base.:(<)(x::TrackedReal, y::TrackedReal) = data(x) < data(y) -Base.:(==)(x::TrackedReal, y::TrackedReal) = data(x) == data(y) -Base.:(≈)(x::TrackedReal, y::TrackedReal) = data(x) ≈ data(y) +for op in [:(==), :≈, :<] + @eval Base.$op(x::TrackedReal, y::Number) = Base.$op(data(x), y) + @eval Base.$op(x::Number, y::TrackedReal) = Base.$op(x, data(y)) + @eval Base.$op(x::TrackedReal, y::TrackedReal) = Base.$op(data(x), data(y)) +end Base.eps(x::TrackedReal) = eps(data(x))