tracker predicate tweaks
This commit is contained in:
parent
cf6b930f63
commit
84efbbcc84
@ -1,5 +1,5 @@
|
||||
module Tracker
|
||||
import Base: <, ==
|
||||
|
||||
export TrackedArray, param, back!
|
||||
|
||||
data(x) = x
|
||||
@ -41,7 +41,6 @@ TrackedArray(x::AbstractArray) = TrackedArray(Call(nothing), x, zeros(x))
|
||||
param(xs) = TrackedArray(AbstractFloat.(xs))
|
||||
istracked(x::TrackedArray) = true
|
||||
data(x::TrackedArray) = x.data
|
||||
# data(x::TrackedScalar) = x.data[]
|
||||
grad(x::TrackedArray) = x.grad
|
||||
|
||||
# Fallthrough methods
|
||||
@ -55,17 +54,17 @@ Base.similar(x::TrackedArray, dims::Union{AbstractUnitRange,Integer}...) =
|
||||
|
||||
Base.similar(x::TrackedArray, T::Type) = similar(data(x), T)
|
||||
|
||||
#to be merged with data in the future
|
||||
unbox(x::TrackedArray) = data(x)
|
||||
unbox(x::TrackedScalar) = data(x)[]
|
||||
value(x) = x
|
||||
value(x::TrackedArray) = data(x)
|
||||
value(x::TrackedScalar) = data(x)[]
|
||||
|
||||
==(x::TrackedArray, y) = unbox(x) == y
|
||||
==(y, x::TrackedArray) = y == unbox(x)
|
||||
==(x::TrackedArray, y::TrackedArray) = unbox(x) == unbox(x)
|
||||
Base.:(==)(x::TrackedArray, y) = value(x) == y
|
||||
Base.:(==)(y, x::TrackedArray) = y == value(x)
|
||||
Base.:(==)(x::TrackedArray, y::TrackedArray) = value(x) == value(x)
|
||||
|
||||
<(x::TrackedScalar, y) = unbox(x) < y
|
||||
<(x, y::TrackedScalar) = x < unbox(y)
|
||||
<(x::TrackedScalar, y::TrackedScalar) = unbox(x) < unbox(y)
|
||||
Base.isless(x::TrackedScalar, y) = isless(value(x), y)
|
||||
Base.isless(x, y::TrackedScalar) = isless(x, value(y))
|
||||
Base.isless(x::TrackedScalar, y::TrackedScalar) = isless(value(x), value(y))
|
||||
|
||||
Base.show(io::IO, ::Type{TrackedArray{T,N,A}}) where {T,N,A<:AbstractArray{T,N}} =
|
||||
print(io, "TrackedArray{…,$A}")
|
||||
|
Loading…
Reference in New Issue
Block a user