tracker predicate tweaks

This commit is contained in:
Mike J Innes 2017-10-26 12:06:29 +01:00
parent cf6b930f63
commit 84efbbcc84

View File

@ -1,5 +1,5 @@
module Tracker module Tracker
import Base: <, ==
export TrackedArray, param, back! export TrackedArray, param, back!
data(x) = x data(x) = x
@ -41,7 +41,6 @@ TrackedArray(x::AbstractArray) = TrackedArray(Call(nothing), x, zeros(x))
param(xs) = TrackedArray(AbstractFloat.(xs)) param(xs) = TrackedArray(AbstractFloat.(xs))
istracked(x::TrackedArray) = true istracked(x::TrackedArray) = true
data(x::TrackedArray) = x.data data(x::TrackedArray) = x.data
# data(x::TrackedScalar) = x.data[]
grad(x::TrackedArray) = x.grad grad(x::TrackedArray) = x.grad
# Fallthrough methods # Fallthrough methods
@ -55,17 +54,17 @@ Base.similar(x::TrackedArray, dims::Union{AbstractUnitRange,Integer}...) =
Base.similar(x::TrackedArray, T::Type) = similar(data(x), T) Base.similar(x::TrackedArray, T::Type) = similar(data(x), T)
#to be merged with data in the future value(x) = x
unbox(x::TrackedArray) = data(x) value(x::TrackedArray) = data(x)
unbox(x::TrackedScalar) = data(x)[] value(x::TrackedScalar) = data(x)[]
==(x::TrackedArray, y) = unbox(x) == y Base.:(==)(x::TrackedArray, y) = value(x) == y
==(y, x::TrackedArray) = y == unbox(x) Base.:(==)(y, x::TrackedArray) = y == value(x)
==(x::TrackedArray, y::TrackedArray) = unbox(x) == unbox(x) Base.:(==)(x::TrackedArray, y::TrackedArray) = value(x) == value(x)
<(x::TrackedScalar, y) = unbox(x) < y Base.isless(x::TrackedScalar, y) = isless(value(x), y)
<(x, y::TrackedScalar) = x < unbox(y) Base.isless(x, y::TrackedScalar) = isless(x, value(y))
<(x::TrackedScalar, y::TrackedScalar) = unbox(x) < unbox(y) Base.isless(x::TrackedScalar, y::TrackedScalar) = isless(value(x), value(y))
Base.show(io::IO, ::Type{TrackedArray{T,N,A}}) where {T,N,A<:AbstractArray{T,N}} = Base.show(io::IO, ::Type{TrackedArray{T,N,A}}) where {T,N,A<:AbstractArray{T,N}} =
print(io, "TrackedArray{…,$A}") print(io, "TrackedArray{…,$A}")