2017-09-07 03:09:32 +00:00
|
|
|
# Fallback method: values with no more specific `scan` method are ignored.
function scan(x)
  return nothing
end
|
2017-09-07 01:21:35 +00:00
|
|
|
|
2017-09-07 03:09:32 +00:00
|
|
|
# Scan every argument recorded on the call `c`, in order. Returns `nothing`.
function scan(c::Call)
  for arg in c.args
    scan(arg)
  end
  return nothing
end
|
|
|
|
|
|
|
|
# Pre-pass over a tracked array: bump its reference count. On the first visit
# the scan continues into `x.f`; on any later visit a zero-filled `grad`
# buffer is created if the field is still undefined. Returns `nothing`.
function scan(x::TrackedArray)
  visits = (x.ref += 1)
  if visits == 1
    # First time this node is reached: descend into what produced it.
    scan(x.f)
  elseif !isdefined(x, :grad)
    # Reached more than once: make sure there is a buffer to accumulate into.
    x.grad = zeros(x.data)
  end
  return
end
|
|
|
|
|
2017-12-15 02:29:14 +00:00
|
|
|
# Generic fallback: the output value `y` is dropped and the remaining
# arguments are forwarded to `back(f, ...)`.
function back_(f, y, args...)
  return back(f, args...)
end
|
|
|
|
# Unpack a recorded call: dispatch on its function, passing the output `y`,
# the incoming gradient `Δ`, and the call's arguments splatted out.
function back_(c::Call, y, Δ)
  return back_(c.func, y, Δ, c.args...)
end
|
|
|
|
# A `Call{Void}` does nothing on the backward pass.
# NOTE(review): presumably this marks a node with no recorded function — confirm.
function back_(::Call{Void}, y, Δ)
  return nothing
end
|
2017-09-07 03:09:32 +00:00
|
|
|
|
|
|
|
# Backward step for a tracked array: decrement the reference count, add `Δ`
# into `x.grad` when that buffer exists, and once the count reaches zero
# continue the backward pass through `x.f`. Returns `nothing`.
function back(x::TrackedArray, Δ)
  remaining = (x.ref -= 1)
  # With a gradient buffer, the accumulated total is what is passed on;
  # without one, the incoming Δ is forwarded unchanged.
  upstream = Δ
  if isdefined(x, :grad)
    x.grad .+= Δ
    upstream = x.grad
  end
  if remaining == 0
    back_(x.f, x.data, upstream)
  end
  return
end
|
2017-09-07 01:21:35 +00:00
|
|
|
|
2017-09-07 03:09:32 +00:00
|
|
|
# `@back(x, Δ)` expands to code that evaluates `x` once and calls `back` on
# it only when `istracked` returns true; the `Δ` expression is evaluated only
# in that case. Both arguments are escaped so they resolve in the caller's scope.
macro back(x, Δ)
  target = esc(x)
  delta = esc(Δ)
  quote
    tracked = $target
    istracked(tracked) && back(tracked, $delta)
  end
end
|
2017-09-07 03:09:32 +00:00
|
|
|
|
|
|
|
# Interface methods
|
|
|
|
|
|
|
|
# Entry point: run the `scan` pre-pass from `x`, then start the backward pass
# from `x` with gradient `Δ`. Returns whatever `back` returns.
back!(x::TrackedArray, Δ) = (scan(x); back(x, Δ))
|
|
|
|
|
|
|
|
# Convenience method for tracked scalars: use a gradient of 1.
function back!(x::TrackedScalar)
  return back!(x, 1)
end
|