diff --git a/src/tracker/back.jl b/src/tracker/back.jl index 37c233e1..e9bf28e0 100644 --- a/src/tracker/back.jl +++ b/src/tracker/back.jl @@ -1,3 +1,5 @@ +init_grad(x) = zero(x) + scan(c::Call) = foreach(scan, c.args) function scan(x::Tracked) @@ -5,7 +7,7 @@ function scan(x::Tracked) if ref == 1 scan(x.f) else - isdefined(x, :grad) || (x.grad = zeros(x.data)) + isdefined(x, :grad) || (x.grad = init_grad(x.data)) end return end @@ -19,13 +21,13 @@ back_(f, y, args...) = back(f, args...) back_(c::Call, y, Δ) = back_(c.func, y, Δ, c.args...) back_(::Call{Void}, y, Δ) = nothing -accum!(x::Tracked, Δ) = (x.grad += Δ) -accum!(x::Tracked{<:AbstractArray}, Δ) = (x.grad .+= Δ) +accum!(x, Δ) = x .+ Δ +accum!(x::AbstractArray, Δ) = (x .+= Δ) function back(x::Tracked, Δ) ref = x.ref -= 1 if isdefined(x, :grad) - accum!(x, Δ) + x.grad = accum!(x.grad, Δ) ref == 0 && back_(x.f, x.data, x.grad) else ref == 0 && back_(x.f, x.data, Δ) diff --git a/src/tracker/scalar.jl b/src/tracker/scalar.jl index 026d2aeb..f37f8c73 100644 --- a/src/tracker/scalar.jl +++ b/src/tracker/scalar.jl @@ -61,3 +61,28 @@ for (M, f, arity) in DiffRules.diffrules() end end end + +# Tuples + +struct TrackedTuple{T<:Tuple} + tracker::Tracked{T} +end + +tracker(xs::TrackedTuple) = xs.tracker + +accum!(x::Tuple, Δ::Tuple) = accum!.(x, Δ) +init_grad(x::Tuple) = init_grad.(x) + +track(f::Call, xs::Tuple) = TrackedTuple(Tracked(f, xs)) + +function Base.show(io::IO, xs::TrackedTuple) + show(io, data(xs)) + print(io, " (tracked)") +end + +Base.length(x::TrackedTuple) = length(data(x)) + +Base.getindex(xs::TrackedTuple, i::Integer) = track(getindex, xs, i) + +back(::typeof(getindex), Δ, t, i) = + back(t, ntuple(j -> i == j ? Δ : 0, length(t)))