Flux.jl/src/tracker/back.jl

57 lines
1.1 KiB
Julia
Raw Normal View History

2017-09-07 03:09:32 +00:00
# Reference-counting pass over a call node: visit each recorded argument.
function scan(c::Call)
  for arg in c.args
    scan(arg)
  end
  return
end
# Reference-counting pass run by `back!` before `back`.
# Increments `x.ref` every time `x` is reached. On the first visit it recurses
# into `x.f`, the node that produced `x`. On any later visit `x` has several
# consumers, so it makes sure a `grad` buffer exists for `back` to accumulate
# their contributions into.
function scan(x::Tracked)
  ref = x.ref += 1
  if ref == 1
    scan(x.f)
  else
    # `zero` works for both array and non-array (e.g. scalar) data, matching
    # the generic `accum!(::Tracked, Δ)` path; `zeros(x.data)` only accepts
    # arrays and would throw a MethodError for scalar tracked values.
    isdefined(x, :grad) || (x.grad = zero(x.data))
  end
  return
end
# Fallback for values that are not `Tracked` themselves: scan their tracker if
# `istracked` says they have one, otherwise do nothing.
function scan(x)
  if istracked(x)
    scan(tracker(x))
  end
  return
end
2017-12-15 02:29:14 +00:00
# `back_(node, y, Δ)` dispatches on the node that produced a tracked value:
#  - a `Call{Void}` (no function recorded) ends propagation;
#  - any other `Call` is unpacked into its function and recorded arguments;
#  - for a plain function `f`, the output value `y` is dropped and
#    `back(f, Δ, args...)` is invoked. `y` is passed along presumably so that
#    more specific `back_` methods elsewhere can use the forward result.
back_(::Call{Void}, y, Δ) = nothing

function back_(c::Call, y, Δ)
  return back_(c.func, y, Δ, c.args...)
end

function back_(f, y, args...)
  return back(f, args...)
end
2017-09-07 03:09:32 +00:00
2018-02-07 20:39:36 +00:00
# Add an incoming gradient `Δ` into `x.grad`.
# Array gradients are updated in place; any other gradient is rebound.
accum!(x::Tracked{<:AbstractArray}, Δ) = (x.grad .+= Δ)
accum!(x::Tracked, Δ) = (x.grad = x.grad + Δ)
# Receive one gradient contribution `Δ` for `x` and decrement its reference
# count. If `x` has a `grad` buffer (set up by `scan` for multiply-used nodes),
# `Δ` is accumulated into it. Once the last consumer has reported
# (`ref == 0`), the total gradient is pushed further back through `x.f`.
function back(x::Tracked, Δ)
  remaining = x.ref -= 1
  if isdefined(x, :grad)
    accum!(x, Δ)
    outgoing = x.grad
  else
    outgoing = Δ
  end
  if remaining == 0
    back_(x.f, x.data, outgoing)
  end
  return
end
2017-09-07 01:21:35 +00:00
# Non-`Tracked` values are backpropagated through their tracker.
function back(x, Δ)
  t = tracker(x)
  return back(t, Δ)
end

# Reached when there is no tracker to propagate into (presumably `tracker`
# returned `nothing` for an untracked value); fail loudly.
function back(x::Void, Δ)
  error("Can't backpropagate through `nothing`")
end
2017-09-07 03:09:32 +00:00
# `@back(x, Δ)`: backpropagate `Δ` into `x` only if `x` is tracked.
# `x` is evaluated exactly once. `Δ` is evaluated only when `x` is tracked,
# so gradient expressions for untracked arguments are never computed.
macro back(x, Δ)
  return quote
    val = $(esc(x))
    istracked(val) && back(val, $(esc(Δ)))
  end
end
2017-09-07 03:09:32 +00:00
# Interface methods
2017-12-15 16:17:45 +00:00
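# TODO: if an error is thrown inside `back`, the reference counts set up by
# `scan` are left non-zero, so later calls to `back!` never reach `ref == 0`
# for those nodes and silently fail to propagate gradients.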
"""
    back!(x, Δ)

Backpropagate the gradient `Δ` from `x` through the recorded graph.

First run `scan` to set up reference counts, then `back` from `x`. Values
that are not `Tracked` are first converted with `tracker`. Returns `nothing`.
"""
back!(x, Δ) = back!(tracker(x), Δ)

function back!(x::Tracked, Δ)
  scan(x)
  return back(x, Δ)
end