tracked tuples
This commit is contained in:
parent
79e4e25fea
commit
39f7f8fdf3
@ -1,3 +1,5 @@
|
|||||||
|
init_grad(x) = zero(x)
|
||||||
|
|
||||||
scan(c::Call) = foreach(scan, c.args)
|
scan(c::Call) = foreach(scan, c.args)
|
||||||
|
|
||||||
function scan(x::Tracked)
|
function scan(x::Tracked)
|
||||||
@ -5,7 +7,7 @@ function scan(x::Tracked)
|
|||||||
if ref == 1
|
if ref == 1
|
||||||
scan(x.f)
|
scan(x.f)
|
||||||
else
|
else
|
||||||
isdefined(x, :grad) || (x.grad = zeros(x.data))
|
isdefined(x, :grad) || (x.grad = init_grad(x.data))
|
||||||
end
|
end
|
||||||
return
|
return
|
||||||
end
|
end
|
||||||
@ -19,13 +21,13 @@ back_(f, y, args...) = back(f, args...)
|
|||||||
back_(c::Call, y, Δ) = back_(c.func, y, Δ, c.args...)
|
back_(c::Call, y, Δ) = back_(c.func, y, Δ, c.args...)
|
||||||
back_(::Call{Void}, y, Δ) = nothing
|
back_(::Call{Void}, y, Δ) = nothing
|
||||||
|
|
||||||
accum!(x::Tracked, Δ) = (x.grad += Δ)
|
accum!(x, Δ) = x .+ Δ
|
||||||
accum!(x::Tracked{<:AbstractArray}, Δ) = (x.grad .+= Δ)
|
accum!(x::AbstractArray, Δ) = (x .+= Δ)
|
||||||
|
|
||||||
function back(x::Tracked, Δ)
|
function back(x::Tracked, Δ)
|
||||||
ref = x.ref -= 1
|
ref = x.ref -= 1
|
||||||
if isdefined(x, :grad)
|
if isdefined(x, :grad)
|
||||||
accum!(x, Δ)
|
x.grad = accum!(x.grad, Δ)
|
||||||
ref == 0 && back_(x.f, x.data, x.grad)
|
ref == 0 && back_(x.f, x.data, x.grad)
|
||||||
else
|
else
|
||||||
ref == 0 && back_(x.f, x.data, Δ)
|
ref == 0 && back_(x.f, x.data, Δ)
|
||||||
|
@ -61,3 +61,28 @@ for (M, f, arity) in DiffRules.diffrules()
|
|||||||
end
|
end
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
|
# Tuples
|
||||||
|
|
||||||
|
struct TrackedTuple{T<:Tuple}
|
||||||
|
tracker::Tracked{T}
|
||||||
|
end
|
||||||
|
|
||||||
|
tracker(xs::TrackedTuple) = xs.tracker
|
||||||
|
|
||||||
|
accum!(x::Tuple, Δ::Tuple) = accum!.(x, Δ)
|
||||||
|
init_grad(x::Tuple) = init_grad.(x)
|
||||||
|
|
||||||
|
track(f::Call, xs::Tuple) = TrackedTuple(Tracked(f, xs))
|
||||||
|
|
||||||
|
function Base.show(io::IO, xs::TrackedTuple)
|
||||||
|
show(io, data(xs))
|
||||||
|
print(io, " (tracked)")
|
||||||
|
end
|
||||||
|
|
||||||
|
Base.length(x::TrackedTuple) = length(data(x))
|
||||||
|
|
||||||
|
Base.getindex(xs::TrackedTuple, i::Integer) = track(getindex, xs, i)
|
||||||
|
|
||||||
|
back(::typeof(getindex), Δ, t, i) =
|
||||||
|
back(t, ntuple(j -> i == j ? Δ : 0, length(t)))
|
||||||
|
Loading…
Reference in New Issue
Block a user