2018-02-07 22:20:44 +00:00
|
|
|
# Fresh gradient buffer for a value: a zero of the same kind as `x`
# (e.g. a zero-filled array when `x` is an array).
function init_grad(x)
  return zero(x)
end
|
|
|
|
|
2017-09-07 03:09:32 +00:00
|
|
|
# Counting pass over a recorded call: run `scan` on every argument of the call.
# Untracked arguments are skipped by the fallback `scan(x)` method.
function scan(c::Call)
  for arg in c.args
    scan(arg)
  end
end
|
|
|
|
|
2018-02-07 17:43:25 +00:00
|
|
|
# Counting pass for one tracked node: bump `x.ref` each time `x` is reached.
# On the first visit, recurse into the call `x.f` that `x` records. On any
# later visit, `x` is referenced more than once, so make sure it has a `grad`
# buffer that `back` can accumulate the separate contributions into.
function scan(x::Tracked)
  visits = (x.ref += 1)
  if visits != 1
    isdefined(x, :grad) || (x.grad = init_grad(x.data))
  else
    scan(x.f)
  end
  return
end
|
|
|
|
|
2018-02-07 17:43:25 +00:00
|
|
|
# Fallback counting pass for arbitrary values: only values for which
# `istracked` holds take part (via their `tracker`). Everything else is ignored.
function scan(x)
  if istracked(x)
    scan(tracker(x))
  end
  return
end
|
|
|
|
|
2017-12-15 02:29:14 +00:00
|
|
|
# Default backward step: forward everything after the output value `y` to the
# `back` method for `f`. When reached through `back_(c::Call, y, Δ)`, `args`
# is `(Δ, call_args...)`. `y` itself is unused by this default.
function back_(f, y, args...)
  return back(f, args...)
end
|
|
|
|
# Unpack a recorded call and dispatch on its function, passing along the
# output value `y`, the incoming gradient `Δ`, and the call's arguments.
function back_(c::Call, y, Δ)
  fn = c.func
  inputs = c.args
  return back_(fn, y, Δ, inputs...)
end
|
|
|
|
# A call whose function type is `Void` is terminal for the backward pass.
# Nothing is propagated past it.
function back_(::Call{Void}, y, Δ)
  return nothing
end
|
2017-09-07 03:09:32 +00:00
|
|
|
|
2018-02-07 22:20:44 +00:00
|
|
|
# Combine an existing gradient `x` with a new contribution `Δ`. This generic
# method does not mutate `x`: it returns a new value, so callers must store
# the result (as `back` does with `x.grad = accum!(...)`).
function accum!(x, Δ)
  combined = x .+ Δ
  return combined
end
|
|
|
|
# Array gradients are accumulated in place. The mutated array itself is
# returned, so callers can store the result exactly as with the generic method.
function accum!(x::AbstractArray, Δ)
  x .= x .+ Δ
  return x
end
|
2018-02-07 20:39:36 +00:00
|
|
|
|
2018-02-07 17:43:25 +00:00
|
|
|
# Backward step for one tracked node. Each call consumes one of the references
# counted by `scan` (decrementing `x.ref`). If `x` has a `grad` buffer, the
# incoming `Δ` is added into it. Once the count reaches zero, the gradient is
# handed to the recorded call `x.f` (with `x.data` as its output value). That
# gradient is the accumulated buffer when one exists, otherwise `Δ` itself.
# Without a buffer, a `Δ` arriving while references remain would be dropped.
# `scan` allocates a buffer for every node it reaches more than once, which
# is why `back!` always runs `scan` first.
function back(x::Tracked, Δ)
  remaining = (x.ref -= 1)
  outgoing = Δ
  if isdefined(x, :grad)
    x.grad = accum!(x.grad, Δ)
    outgoing = x.grad
  end
  remaining == 0 && back_(x.f, x.data, outgoing)
  return
end
|
2017-09-07 01:21:35 +00:00
|
|
|
|
2018-02-07 17:43:25 +00:00
|
|
|
# Generic entry: backpropagate `Δ` through the tracker attached to `x`.
function back(x, Δ)
  t = tracker(x)
  return back(t, Δ)
end
|
2018-02-07 20:39:36 +00:00
|
|
|
# Reject `nothing` with an explicit error instead of letting it reach the
# generic `back(x, Δ)` method.
function back(x::Void, Δ)
  error("Can't backpropagate through `nothing`")
end
|
2018-02-07 17:43:25 +00:00
|
|
|
|
2017-09-07 03:09:32 +00:00
|
|
|
# `@back x Δ`: evaluate `x` once in the caller's scope and, only if it is
# tracked, backpropagate `Δ` through it. `Δ` is not evaluated otherwise. When
# `x` is untracked, the expansion evaluates to the (false) result of `istracked`.
macro back(x, Δ)
  value_expr = esc(x)
  delta_expr = esc(Δ)
  quote
    val = $value_expr
    istracked(val) && back(val, $delta_expr)
  end
end
|
2017-09-07 03:09:32 +00:00
|
|
|
|
|
|
|
# Interface methods
|
|
|
|
|
2017-12-15 16:17:45 +00:00
|
|
|
# TODO: if an error occurs in `back` the refcounts will be broken
|
|
|
|
# and later `back!` calls will silently fail to propagate gradients.
|
|
|
|
|
2018-02-07 17:43:25 +00:00
|
|
|
# Full backpropagation from `x` in two passes. First, `scan` counts references
# and allocates accumulation buffers for multiply-reached nodes. Second, `back`
# runs the backward pass seeded with `Δ`.
function back!(x::Tracked, Δ)
  scan(x)
  result = back(x, Δ)
  return result
end
|
|
|
|
|
2018-02-07 17:43:25 +00:00
|
|
|
# Convenience entry point: run `back!` on the tracker attached to `x`.
function back!(x, Δ)
  t = tracker(x)
  return back!(t, Δ)
end
|