diff --git a/src/tracker/Tracker.jl b/src/tracker/Tracker.jl index 90707ea5..8a481970 100644 --- a/src/tracker/Tracker.jl +++ b/src/tracker/Tracker.jl @@ -1,5 +1,5 @@ module Tracker -import Base: <, == + export TrackedArray, param, back! data(x) = x @@ -41,7 +41,6 @@ TrackedArray(x::AbstractArray) = TrackedArray(Call(nothing), x, zeros(x)) param(xs) = TrackedArray(AbstractFloat.(xs)) istracked(x::TrackedArray) = true data(x::TrackedArray) = x.data -# data(x::TrackedScalar) = x.data[] grad(x::TrackedArray) = x.grad # Fallthrough methods @@ -55,17 +54,17 @@ Base.similar(x::TrackedArray, dims::Union{AbstractUnitRange,Integer}...) = Base.similar(x::TrackedArray, T::Type) = similar(data(x), T) -#to be merged with data in the future -unbox(x::TrackedArray) = data(x) -unbox(x::TrackedScalar) = data(x)[] +value(x) = x +value(x::TrackedArray) = data(x) +value(x::TrackedScalar) = data(x)[] -==(x::TrackedArray, y) = unbox(x) == y -==(y, x::TrackedArray) = y == unbox(x) -==(x::TrackedArray, y::TrackedArray) = unbox(x) == unbox(x) +Base.:(==)(x::TrackedArray, y) = value(x) == y +Base.:(==)(y, x::TrackedArray) = y == value(x) +Base.:(==)(x::TrackedArray, y::TrackedArray) = value(x) == value(x) -<(x::TrackedScalar, y) = unbox(x) < y -<(x, y::TrackedScalar) = x < unbox(y) -<(x::TrackedScalar, y::TrackedScalar) = unbox(x) < unbox(y) +Base.isless(x::TrackedScalar, y) = isless(value(x), y) +Base.isless(x, y::TrackedScalar) = isless(x, value(y)) +Base.isless(x::TrackedScalar, y::TrackedScalar) = isless(value(x), value(y)) Base.show(io::IO, ::Type{TrackedArray{T,N,A}}) where {T,N,A<:AbstractArray{T,N}} = print(io, "TrackedArray{…,$A}")