tracked tuples

This commit is contained in:
Mike J Innes 2018-02-07 22:20:44 +00:00
parent 79e4e25fea
commit 39f7f8fdf3
2 changed files with 31 additions and 4 deletions

View File

@ -1,3 +1,5 @@
init_grad(x) = zero(x)
scan(c::Call) = foreach(scan, c.args)
function scan(x::Tracked)
@ -5,7 +7,7 @@ function scan(x::Tracked)
if ref == 1
scan(x.f)
else
isdefined(x, :grad) || (x.grad = zeros(x.data))
isdefined(x, :grad) || (x.grad = init_grad(x.data))
end
return
end
@ -19,13 +21,13 @@ back_(f, y, args...) = back(f, args...)
back_(c::Call, y, Δ) = back_(c.func, y, Δ, c.args...)
back_(::Call{Void}, y, Δ) = nothing
accum!(x::Tracked, Δ) = (x.grad += Δ)
accum!(x::Tracked{<:AbstractArray}, Δ) = (x.grad .+= Δ)
accum!(x, Δ) = x .+ Δ
accum!(x::AbstractArray, Δ) = (x .+= Δ)
function back(x::Tracked, Δ)
ref = x.ref -= 1
if isdefined(x, :grad)
accum!(x, Δ)
x.grad = accum!(x.grad, Δ)
ref == 0 && back_(x.f, x.data, x.grad)
else
ref == 0 && back_(x.f, x.data, Δ)

View File

@ -61,3 +61,28 @@ for (M, f, arity) in DiffRules.diffrules()
end
end
end
# Tuples
struct TrackedTuple{T<:Tuple}
tracker::Tracked{T}
end
tracker(xs::TrackedTuple) = xs.tracker
accum!(x::Tuple, Δ::Tuple) = accum!.(x, Δ)
init_grad(x::Tuple) = init_grad.(x)
track(f::Call, xs::Tuple) = TrackedTuple(Tracked(f, xs))
function Base.show(io::IO, xs::TrackedTuple)
show(io, data(xs))
print(io, " (tracked)")
end
Base.length(x::TrackedTuple) = length(data(x))
Base.getindex(xs::TrackedTuple, i::Integer) = track(getindex, xs, i)
back(::typeof(getindex), Δ, t, i) =
back(t, ntuple(j -> i == j ? Δ : 0, length(t)))