tracked tuples
This commit is contained in:
parent
79e4e25fea
commit
39f7f8fdf3
@ -1,3 +1,5 @@
|
||||
init_grad(x) = zero(x)
|
||||
|
||||
scan(c::Call) = foreach(scan, c.args)
|
||||
|
||||
function scan(x::Tracked)
|
||||
@ -5,7 +7,7 @@ function scan(x::Tracked)
|
||||
if ref == 1
|
||||
scan(x.f)
|
||||
else
|
||||
isdefined(x, :grad) || (x.grad = zeros(x.data))
|
||||
isdefined(x, :grad) || (x.grad = init_grad(x.data))
|
||||
end
|
||||
return
|
||||
end
|
||||
@ -19,13 +21,13 @@ back_(f, y, args...) = back(f, args...)
|
||||
back_(c::Call, y, Δ) = back_(c.func, y, Δ, c.args...)
|
||||
back_(::Call{Void}, y, Δ) = nothing
|
||||
|
||||
accum!(x::Tracked, Δ) = (x.grad += Δ)
|
||||
accum!(x::Tracked{<:AbstractArray}, Δ) = (x.grad .+= Δ)
|
||||
accum!(x, Δ) = x .+ Δ
|
||||
accum!(x::AbstractArray, Δ) = (x .+= Δ)
|
||||
|
||||
function back(x::Tracked, Δ)
|
||||
ref = x.ref -= 1
|
||||
if isdefined(x, :grad)
|
||||
accum!(x, Δ)
|
||||
x.grad = accum!(x.grad, Δ)
|
||||
ref == 0 && back_(x.f, x.data, x.grad)
|
||||
else
|
||||
ref == 0 && back_(x.f, x.data, Δ)
|
||||
|
@ -61,3 +61,28 @@ for (M, f, arity) in DiffRules.diffrules()
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
# Tuples
|
||||
|
||||
struct TrackedTuple{T<:Tuple}
|
||||
tracker::Tracked{T}
|
||||
end
|
||||
|
||||
tracker(xs::TrackedTuple) = xs.tracker
|
||||
|
||||
accum!(x::Tuple, Δ::Tuple) = accum!.(x, Δ)
|
||||
init_grad(x::Tuple) = init_grad.(x)
|
||||
|
||||
track(f::Call, xs::Tuple) = TrackedTuple(Tracked(f, xs))
|
||||
|
||||
function Base.show(io::IO, xs::TrackedTuple)
|
||||
show(io, data(xs))
|
||||
print(io, " (tracked)")
|
||||
end
|
||||
|
||||
Base.length(x::TrackedTuple) = length(data(x))
|
||||
|
||||
Base.getindex(xs::TrackedTuple, i::Integer) = track(getindex, xs, i)
|
||||
|
||||
back(::typeof(getindex), Δ, t, i) =
|
||||
back(t, ntuple(j -> i == j ? Δ : 0, length(t)))
|
||||
|
Loading…
Reference in New Issue
Block a user