tracker predicate tweaks
This commit is contained in:
parent
cf6b930f63
commit
84efbbcc84
@ -1,5 +1,5 @@
|
|||||||
module Tracker
|
module Tracker
|
||||||
import Base: <, ==
|
|
||||||
export TrackedArray, param, back!
|
export TrackedArray, param, back!
|
||||||
|
|
||||||
data(x) = x
|
data(x) = x
|
||||||
@ -41,7 +41,6 @@ TrackedArray(x::AbstractArray) = TrackedArray(Call(nothing), x, zeros(x))
|
|||||||
param(xs) = TrackedArray(AbstractFloat.(xs))
|
param(xs) = TrackedArray(AbstractFloat.(xs))
|
||||||
istracked(x::TrackedArray) = true
|
istracked(x::TrackedArray) = true
|
||||||
data(x::TrackedArray) = x.data
|
data(x::TrackedArray) = x.data
|
||||||
# data(x::TrackedScalar) = x.data[]
|
|
||||||
grad(x::TrackedArray) = x.grad
|
grad(x::TrackedArray) = x.grad
|
||||||
|
|
||||||
# Fallthrough methods
|
# Fallthrough methods
|
||||||
@ -55,17 +54,17 @@ Base.similar(x::TrackedArray, dims::Union{AbstractUnitRange,Integer}...) =
|
|||||||
|
|
||||||
Base.similar(x::TrackedArray, T::Type) = similar(data(x), T)
|
Base.similar(x::TrackedArray, T::Type) = similar(data(x), T)
|
||||||
|
|
||||||
#to be merged with data in the future
|
value(x) = x
|
||||||
unbox(x::TrackedArray) = data(x)
|
value(x::TrackedArray) = data(x)
|
||||||
unbox(x::TrackedScalar) = data(x)[]
|
value(x::TrackedScalar) = data(x)[]
|
||||||
|
|
||||||
==(x::TrackedArray, y) = unbox(x) == y
|
Base.:(==)(x::TrackedArray, y) = value(x) == y
|
||||||
==(y, x::TrackedArray) = y == unbox(x)
|
Base.:(==)(y, x::TrackedArray) = y == value(x)
|
||||||
==(x::TrackedArray, y::TrackedArray) = unbox(x) == unbox(x)
|
Base.:(==)(x::TrackedArray, y::TrackedArray) = value(x) == value(x)
|
||||||
|
|
||||||
<(x::TrackedScalar, y) = unbox(x) < y
|
Base.isless(x::TrackedScalar, y) = isless(value(x), y)
|
||||||
<(x, y::TrackedScalar) = x < unbox(y)
|
Base.isless(x, y::TrackedScalar) = isless(x, value(y))
|
||||||
<(x::TrackedScalar, y::TrackedScalar) = unbox(x) < unbox(y)
|
Base.isless(x::TrackedScalar, y::TrackedScalar) = isless(value(x), value(y))
|
||||||
|
|
||||||
Base.show(io::IO, ::Type{TrackedArray{T,N,A}}) where {T,N,A<:AbstractArray{T,N}} =
|
Base.show(io::IO, ::Type{TrackedArray{T,N,A}}) where {T,N,A<:AbstractArray{T,N}} =
|
||||||
print(io, "TrackedArray{…,$A}")
|
print(io, "TrackedArray{…,$A}")
|
||||||
|
Loading…
Reference in New Issue
Block a user