diff --git a/src/tracker/Tracker.jl b/src/tracker/Tracker.jl index 96ed3bcf..4b7771c5 100644 --- a/src/tracker/Tracker.jl +++ b/src/tracker/Tracker.jl @@ -28,6 +28,7 @@ mutable struct Tracked{T} end Tracked(f::Call, x) = Tracked{typeof(x)}(f, x) +Tracked(f::Call, x, Δ) = Tracked{typeof(x)}(f, x, Δ) track(f::Call, x) = Tracked(f, x) track(f::Call) = track(f, f()) diff --git a/src/tracker/scalar.jl b/src/tracker/scalar.jl index 3f38d9f0..ab003f90 100644 --- a/src/tracker/scalar.jl +++ b/src/tracker/scalar.jl @@ -2,11 +2,11 @@ struct TrackedNumber{T<:Number} <: Number tracker::Tracked{T} end -TrackedNumber(x::Number) = TrackedNumber(Tracked(Call(nothing), x)) +TrackedNumber(x::Number) = TrackedNumber(Tracked(Call(nothing), x, zero(x))) tracker(x::TrackedNumber) = x.tracker -track(f::Call, x::Number) = TrackedNumber(Tracked(f, x)) +track(f::Call, x::Number) = TrackedNumber(Tracked(f, x, zero(x))) back!(x::TrackedNumber) = back!(x, 1)