Flux.jl/src/tracker/back.jl

45 lines
774 B
Julia
Raw Normal View History

2017-09-07 03:09:32 +00:00
# Fallback for any value with no more specific `scan` method (e.g. untracked
# inputs): there is no reference count to bump and nothing to recurse into.
function scan(x)
    return nothing
end

# A recorded call: walk each of its arguments so tracked inputs get counted.
function scan(c::Call)
    for arg in c.args
        scan(arg)
    end
    return nothing
end
"""
    scan(x::TrackedArray)

Pre-pass for the backward sweep: increment `x.ref`, the number of times `x`
is reached from the output, and prepare `x` for gradient accumulation.

- On the first visit (`ref == 1`), recurse into `x.f` so that the inputs of
  the call that produced `x` are counted as well.
- On any later visit, `x` feeds into more than one consumer. Its incoming
  gradients must then be summed before they are propagated, so an
  accumulation buffer `x.grad` is allocated with the same shape and element
  type as `x.data` if one does not exist yet.

Returns `nothing`.
"""
function scan(x::TrackedArray)
    ref = x.ref += 1
    if ref == 1
        scan(x.f)
    else
        # `zero(A)` == `fill!(similar(A), zero(eltype(A)))`; replaces the
        # `zeros(A)` form, which was deprecated in Julia 0.7 and removed in 1.0.
        # NOTE(review): an already-defined `grad` is reused as-is (not zeroed),
        # so a second `back!` over the same graph may accumulate onto values
        # left from the previous pass — confirm this is intended.
        isdefined(x, :grad) || (x.grad = zero(x.data))
    end
    return
end
2017-12-15 02:29:14 +00:00
# A call with a `Void` function has no recorded operation (the `Call{Void}`
# form): the gradient stops here.
function back_(::Call{Void}, y, Δ)
    return nothing
end

# Unpack a recorded call and dispatch on the function that was applied,
# passing along the forward output `y`, the incoming gradient and the inputs.
function back_(c::Call, y, Δ)
    return back_(c.func, y, Δ, c.args...)
end

# Default: the forward output `y` is not needed, so drop it and hand the
# remaining arguments to `back`.
function back_(f, y, args...)
    return back(f, args...)
end
2017-09-07 03:09:32 +00:00
"""
    back(x::TrackedArray, Δ)

Deliver one incoming gradient `Δ` to `x` and consume one of the references
counted by `scan`. Once every counted use has reported (the count reaches
zero), pass the gradient on to the call that produced `x`.

Returns `nothing`.
"""
function back(x::TrackedArray, Δ)
    remaining = (x.ref -= 1)
    if !isdefined(x, :grad)
        # No accumulation buffer exists, so forward Δ unchanged once the count
        # reaches zero.
        remaining == 0 && back_(x.f, x.data, Δ)
        return
    end
    # A buffer exists: add this contribution, then propagate the running
    # total once the last use has arrived.
    x.grad .+= Δ
    remaining == 0 && back_(x.f, x.data, x.grad)
    return
end
2017-09-07 01:21:35 +00:00
2017-09-07 03:09:32 +00:00
# `@back x Δ`: evaluate `x` exactly once and, only if it is tracked, evaluate
# `Δ` and run `back` on it. The value is `false` when `x` is untracked,
# otherwise whatever `back` returns.
macro back(x, Δ)
    tmp = gensym(:x)
    return quote
        $tmp = $(esc(x))
        istracked($tmp) && back($tmp, $(esc(Δ)))
    end
end
2017-09-07 03:09:32 +00:00
# Interface methods
"""
    back!(x::TrackedArray, Δ)

Run a complete backward pass from `x` seeded with gradient `Δ`. First `scan`
counts references throughout the graph, then `back` propagates the gradient.
"""
function back!(x::TrackedArray, Δ)
    scan(x)
    return back(x, Δ)
end

"""
    back!(x::TrackedScalar)

Backward pass from a tracked scalar, seeded with a gradient of `1`.
"""
function back!(x::TrackedScalar)
    seed = 1
    return back!(x, seed)
end