diff --git a/src/tracker/Tracker.jl b/src/tracker/Tracker.jl index 4b7771c5..cb2547bc 100644 --- a/src/tracker/Tracker.jl +++ b/src/tracker/Tracker.jl @@ -44,7 +44,7 @@ include("scalar.jl") include("array.jl") include("numeric.jl") -param(x::Number) = TrackedNumber(float(x)) +param(x::Number) = TrackedReal(float(x)) param(xs::AbstractArray) = TrackedArray(float.(xs)) using DataFlow diff --git a/src/tracker/array.jl b/src/tracker/array.jl index 93ec7bce..96326570 100644 --- a/src/tracker/array.jl +++ b/src/tracker/array.jl @@ -237,7 +237,7 @@ end dualify(xs, n) = xs dualify(xs::TrackedArray, ps) = map(x -> Dual(x, ps), data(xs)) -dualify(xs::TrackedNumber, ps) = Dual(data(xs), ps) +dualify(xs::TrackedReal, ps) = Dual(data(xs), ps) function tracked_broadcast(f, args::Vararg{Any,N}) where N dargs = map((x,i) -> dualify(x, ntuple(j -> i==j, Val{N})), args, ntuple(identity, Val{N})) diff --git a/src/tracker/scalar.jl b/src/tracker/scalar.jl index ab003f90..74e26d95 100644 --- a/src/tracker/scalar.jl +++ b/src/tracker/scalar.jl @@ -1,51 +1,46 @@ -struct TrackedNumber{T<:Number} <: Number +struct TrackedReal{T<:Real} <: Real tracker::Tracked{T} end -TrackedNumber(x::Number) = TrackedNumber(Tracked(Call(nothing), x, zero(x))) +TrackedReal(x::Real) = TrackedReal(Tracked(Call(nothing), x, zero(x))) -tracker(x::TrackedNumber) = x.tracker +tracker(x::TrackedReal) = x.tracker -track(f::Call, x::Number) = TrackedNumber(Tracked(f, x, zero(x))) +track(f::Call, x::Real) = TrackedReal(Tracked(f, x, zero(x))) -back!(x::TrackedNumber) = back!(x, 1) +back!(x::TrackedReal) = back!(x, 1) -function Base.show(io::IO, x::TrackedNumber) +function Base.show(io::IO, x::TrackedReal) show(io, data(x)) print(io, " (tracked)") end -Base.convert(::Type{TrackedNumber{T}}, x::TrackedNumber{T}) where T = x +Base.convert(::Type{TrackedReal{T}}, x::TrackedReal{T}) where T = x -Base.convert(::Type{TrackedNumber{T}}, x::TrackedNumber) where T = - TrackedNumber(Tracked(x.tracker.f, convert(T, x.tracker.data))) +Base.convert(::Type{TrackedReal{T}}, x::TrackedReal) where T = + TrackedReal(Tracked(x.tracker.f, convert(T, x.tracker.data))) -Base.convert(::Type{TrackedNumber{T}}, x::Number) where T = TrackedNumber(convert(T, x)) +Base.convert(::Type{TrackedReal{T}}, x::Real) where T = TrackedReal(convert(T, x)) -Base.isless(x::TrackedNumber, y::Number) = isless(data(x), y) -Base.isless(x::Number, y::TrackedNumber) = isless(x, data(y)) -Base.isless(x::TrackedNumber, y::TrackedNumber) = isless(data(x), data(y)) - -Base.:(==)(x::TrackedNumber, y::Number) = data(x) == y -Base.:(==)(x::Number, y::TrackedNumber) = x == data(y) -Base.:(==)(x::TrackedNumber, y::TrackedNumber) = data(x) == data(y) +Base.:(<)(x::TrackedReal, y::TrackedReal) = data(x) < data(y) +Base.:(==)(x::TrackedReal, y::TrackedReal) = data(x) == data(y) for f in :[isinf, isnan, isfinite].args - @eval Base.$f(x::TrackedNumber) = Base.$f(data(x)) + @eval Base.$f(x::TrackedReal) = Base.$f(data(x)) end -Base.Printf.fix_dec(x::TrackedNumber, n::Int) = Base.Printf.fix_dec(data(x), n) +# Base.Printf.fix_dec(x::TrackedReal, n::Int) = Base.Printf.fix_dec(data(x), n) -Base.promote_rule(::Type{TrackedNumber{S}},::Type{T}) where {S,T} = - TrackedNumber{promote_type(S,T)} +Base.promote_rule(::Type{TrackedReal{S}},::Type{T}) where {S,T} = + TrackedReal{promote_type(S,T)} using DiffRules, SpecialFunctions, NaNMath for (M, f, arity) in DiffRules.diffrules() arity == 1 || continue @eval begin - $M.$f(a::TrackedNumber) = track($M.$f, a) - back(::typeof($M.$f), Δ::Number, a::TrackedNumber) = + $M.$f(a::TrackedReal) = track($M.$f, a) + back(::typeof($M.$f), Δ::Real, a::TrackedReal) = back(a, Δ * $(DiffRules.diffrule(M, f, :(data(a))))) end end @@ -54,10 +49,10 @@ for (M, f, arity) in DiffRules.diffrules() arity == 2 || continue da, db = DiffRules.diffrule(M, f, :(data(a)), :(data(b))) @eval begin - $M.$f(a::TrackedNumber, b::TrackedNumber) = track($M.$f, a, b) - $M.$f(a::TrackedNumber, b::Number) = track($M.$f, a, b) - $M.$f(a::Number, b::TrackedNumber) = track($M.$f, a, b) - function back(::typeof($M.$f), Δ::Number, a::Number, b::Number) + $M.$f(a::TrackedReal, b::TrackedReal) = track($M.$f, a, b) + $M.$f(a::TrackedReal, b::Real) = track($M.$f, a, b) + $M.$f(a::Real, b::TrackedReal) = track($M.$f, a, b) + function back(::typeof($M.$f), Δ::Real, a::Real, b::Real) @back(a, Δ * $da) @back(b, Δ * $db) end