|
|
|
@ -2,66 +2,64 @@ init_grad(x) = zero(x)
|
|
|
|
|
zero_grad!(x) = zero(x)
|
|
|
|
|
zero_grad!(x::AbstractArray) = (x .= 0)
|
|
|
|
|
|
|
|
|
|
scan(c::Call) = foreach(scan, c.args)
|
|
|
|
|
|
|
|
|
|
function scan(x::Tracked)
|
|
|
|
|
x.isleaf && return
|
|
|
|
|
ref = x.ref += 1
|
|
|
|
|
if ref == 1
|
|
|
|
|
scan(x.f)
|
|
|
|
|
isdefined(x, :grad) && (x.grad = zero_grad!(x.grad))
|
|
|
|
|
function _walk(queue, seen, c::Call)
|
|
|
|
|
foreach(c.args) do x
|
|
|
|
|
x === nothing && return
|
|
|
|
|
id = objectid(x)
|
|
|
|
|
if id ∉ seen
|
|
|
|
|
push!(seen, id)
|
|
|
|
|
pushfirst!(queue, x)
|
|
|
|
|
end
|
|
|
|
|
return
|
|
|
|
|
end
|
|
|
|
|
return
|
|
|
|
|
end
|
|
|
|
|
|
|
|
|
|
function scan(x)
|
|
|
|
|
istracked(x) && scan(tracker(x))
|
|
|
|
|
return
|
|
|
|
|
function walk(f, x::Tracked; once = true)
|
|
|
|
|
queue = Tracked[x]
|
|
|
|
|
seen = Set{UInt64}()
|
|
|
|
|
while !isempty(queue)
|
|
|
|
|
x = pop!(queue)
|
|
|
|
|
f(x)
|
|
|
|
|
_walk(queue, seen, x.f)
|
|
|
|
|
once && !x.isleaf && (x.f = Call(missing, ()))
|
|
|
|
|
end
|
|
|
|
|
end
|
|
|
|
|
|
|
|
|
|
function back_(c::Call, Δ, once)
|
|
|
|
|
Δs = c.func(Δ)
|
|
|
|
|
(Δs isa Tuple && length(Δs) >= length(c.args)) ||
|
|
|
|
|
error("Gradient is not a tuple of length $(length(c.args))")
|
|
|
|
|
foreach((x, d) -> back(x, d, once), c.args, data.(Δs))
|
|
|
|
|
end
|
|
|
|
|
|
|
|
|
|
back_(::Call{Nothing}, Δ, once) = nothing
|
|
|
|
|
back_(::Call{Missing}, Δ, once) = error("`back!` was already used")
|
|
|
|
|
|
|
|
|
|
accum!(x, Δ) = x .+ Δ
|
|
|
|
|
accum!(x::AbstractArray, Δ) = (x .+= Δ)
|
|
|
|
|
|
|
|
|
|
function back(x::Tracked, Δ, once)
|
|
|
|
|
x.isleaf && (x.grad = accum!(x.grad, Δ); return)
|
|
|
|
|
ref = x.ref -= 1
|
|
|
|
|
grad = if isdefined(x, :grad)
|
|
|
|
|
function _back(x::Tracked, Δ)
|
|
|
|
|
if isdefined(x, :grad)
|
|
|
|
|
x.grad = accum!(x.grad, Δ)
|
|
|
|
|
elseif ref > 0
|
|
|
|
|
x.grad = Δ
|
|
|
|
|
else
|
|
|
|
|
Δ
|
|
|
|
|
end
|
|
|
|
|
if ref == 0
|
|
|
|
|
back_(x.f, grad, once)
|
|
|
|
|
once && !x.isleaf && (x.f = Call(missing, ()))
|
|
|
|
|
x.grad = Δ
|
|
|
|
|
end
|
|
|
|
|
return
|
|
|
|
|
end
|
|
|
|
|
|
|
|
|
|
back(::Nothing, Δ, once) = return
|
|
|
|
|
_back(::Nothing, Δ) = return
|
|
|
|
|
|
|
|
|
|
function _back(c::Call, Δ)
|
|
|
|
|
Δs = c.func(Δ)
|
|
|
|
|
(Δs isa Tuple && length(Δs) >= length(c.args)) ||
|
|
|
|
|
error("Gradient is not a tuple of length $(length(c.args))")
|
|
|
|
|
foreach((x, d) -> _back(x, d), c.args, data.(Δs))
|
|
|
|
|
end
|
|
|
|
|
|
|
|
|
|
_back(::Call{Nothing}, Δ) = nothing
|
|
|
|
|
_back(::Call{Missing}, Δ) = error("`back!` was already used")
|
|
|
|
|
|
|
|
|
|
function back(x::Tracked, Δ, once)
|
|
|
|
|
_back(x, Δ)
|
|
|
|
|
walk(x, once = once) do x
|
|
|
|
|
_back(x.f, x.grad)
|
|
|
|
|
end
|
|
|
|
|
end
|
|
|
|
|
|
|
|
|
|
# Interface methods
|
|
|
|
|
|
|
|
|
|
# TODO: if an error occurs in `back` the refcounts will be broken
|
|
|
|
|
# and `back` will silently fail to update.
|
|
|
|
|
# (but only if you re-use intermediate values between passes)
|
|
|
|
|
# Refcounts are also probably not safe in some situations (e.g. back called
|
|
|
|
|
# from within a backpropagator)
|
|
|
|
|
|
|
|
|
|
function back!(x, Δ; once = true)
|
|
|
|
|
istracked(x) || return
|
|
|
|
|
scan(x)
|
|
|
|
|
back(tracker(x), Δ, once)
|
|
|
|
|
return
|
|
|
|
|
end
|
|
|
|
@ -124,35 +122,35 @@ end
|
|
|
|
|
|
|
|
|
|
accum!(g::Grads, x, Δ) = g[x] = haskey(g, x) ? g[x] .+ Δ : Δ
|
|
|
|
|
|
|
|
|
|
function back_(g::Grads, c::Call, Δ)
|
|
|
|
|
function _back(g::Grads, c::Call, Δ)
|
|
|
|
|
Δs = c.func(Δ)
|
|
|
|
|
(Δs isa Tuple && length(Δs) >= length(c.args)) ||
|
|
|
|
|
error("Gradient is not a tuple of length $(length(c.args))")
|
|
|
|
|
foreach((x, Δ) -> back(g, x, Δ), c.args, Δs)
|
|
|
|
|
foreach((x, Δ) -> _back(g, x, Δ), c.args, Δs)
|
|
|
|
|
end
|
|
|
|
|
|
|
|
|
|
back_(g::Grads, ::Call{Nothing}, Δ) = nothing
|
|
|
|
|
_back(g::Grads, ::Call{Nothing}, Δ) = nothing
|
|
|
|
|
|
|
|
|
|
function back(g::Grads, x::Tracked, Δ)
|
|
|
|
|
function _back(g::Grads, x::Tracked, Δ)
|
|
|
|
|
x.isleaf && (accum!(g, x, Δ); return)
|
|
|
|
|
ref = x.ref -= 1
|
|
|
|
|
if ref > 0 || haskey(g, x)
|
|
|
|
|
accum!(g, x, Δ)
|
|
|
|
|
ref == 0 && back_(g, x.f, g[x])
|
|
|
|
|
else
|
|
|
|
|
ref == 0 && back_(g, x.f, Δ)
|
|
|
|
|
end
|
|
|
|
|
accum!(g, x, Δ)
|
|
|
|
|
return
|
|
|
|
|
end
|
|
|
|
|
|
|
|
|
|
back(::Grads, ::Nothing, _) = return
|
|
|
|
|
_back(g::Grads, ::Nothing, Δ) = return
|
|
|
|
|
|
|
|
|
|
function back(g::Grads, x::Tracked, Δ)
|
|
|
|
|
_back(g, x, Δ)
|
|
|
|
|
walk(x, once = false) do x
|
|
|
|
|
_back(g, x.f, g[x])
|
|
|
|
|
end
|
|
|
|
|
end
|
|
|
|
|
|
|
|
|
|
function forward(f, ps::Params)
|
|
|
|
|
y = f()
|
|
|
|
|
y, function (Δ)
|
|
|
|
|
g = Grads(ps)
|
|
|
|
|
if istracked(y)
|
|
|
|
|
scan(y)
|
|
|
|
|
back(g, tracker(y), Δ)
|
|
|
|
|
end
|
|
|
|
|
return g
|
|
|
|
|