iterative backwards pass
This commit is contained in:
parent
9781f063aa
commit
66580f7f20
|
@ -30,13 +30,12 @@ a::Call == b::Call = a.func == b.func && a.args == b.args
|
|||
@inline (c::Call)() = c.func(data.(c.args)...)
|
||||
|
||||
mutable struct Tracked{T}
|
||||
ref::UInt32
|
||||
f::Call
|
||||
isleaf::Bool
|
||||
grad::T
|
||||
Tracked{T}(f::Call) where T = new(0, f, false)
|
||||
Tracked{T}(f::Call, grad::T) where T = new(0, f, false, grad)
|
||||
Tracked{T}(f::Call{Nothing}, grad::T) where T = new(0, f, true, grad)
|
||||
Tracked{T}(f::Call) where T = new(f, false)
|
||||
Tracked{T}(f::Call, grad::T) where T = new(f, false, grad)
|
||||
Tracked{T}(f::Call{Nothing}, grad::T) where T = new(f, true, grad)
|
||||
end
|
||||
|
||||
istracked(x::Tracked) = true
|
||||
|
|
|
@ -2,66 +2,64 @@ init_grad(x) = zero(x)
|
|||
zero_grad!(x) = zero(x)
|
||||
zero_grad!(x::AbstractArray) = (x .= 0)
|
||||
|
||||
scan(c::Call) = foreach(scan, c.args)
|
||||
|
||||
function scan(x::Tracked)
|
||||
x.isleaf && return
|
||||
ref = x.ref += 1
|
||||
if ref == 1
|
||||
scan(x.f)
|
||||
isdefined(x, :grad) && (x.grad = zero_grad!(x.grad))
|
||||
function _walk(queue, seen, c::Call)
|
||||
foreach(c.args) do x
|
||||
x === nothing && return
|
||||
id = objectid(x)
|
||||
if id ∉ seen
|
||||
push!(seen, id)
|
||||
pushfirst!(queue, x)
|
||||
end
|
||||
return
|
||||
end
|
||||
return
|
||||
end
|
||||
|
||||
function scan(x)
|
||||
istracked(x) && scan(tracker(x))
|
||||
return
|
||||
function walk(f, x::Tracked; once = true)
|
||||
queue = Tracked[x]
|
||||
seen = Set{UInt64}()
|
||||
while !isempty(queue)
|
||||
x = pop!(queue)
|
||||
f(x)
|
||||
_walk(queue, seen, x.f)
|
||||
once && !x.isleaf && (x.f = Call(missing, ()))
|
||||
end
|
||||
end
|
||||
|
||||
function back_(c::Call, Δ, once)
|
||||
Δs = c.func(Δ)
|
||||
(Δs isa Tuple && length(Δs) >= length(c.args)) ||
|
||||
error("Gradient is not a tuple of length $(length(c.args))")
|
||||
foreach((x, d) -> back(x, d, once), c.args, data.(Δs))
|
||||
end
|
||||
|
||||
back_(::Call{Nothing}, Δ, once) = nothing
|
||||
back_(::Call{Missing}, Δ, once) = error("`back!` was already used")
|
||||
|
||||
accum!(x, Δ) = x .+ Δ
|
||||
accum!(x::AbstractArray, Δ) = (x .+= Δ)
|
||||
|
||||
function back(x::Tracked, Δ, once)
|
||||
x.isleaf && (x.grad = accum!(x.grad, Δ); return)
|
||||
ref = x.ref -= 1
|
||||
grad = if isdefined(x, :grad)
|
||||
function _back(x::Tracked, Δ)
|
||||
if isdefined(x, :grad)
|
||||
x.grad = accum!(x.grad, Δ)
|
||||
elseif ref > 0
|
||||
x.grad = Δ
|
||||
else
|
||||
Δ
|
||||
end
|
||||
if ref == 0
|
||||
back_(x.f, grad, once)
|
||||
once && !x.isleaf && (x.f = Call(missing, ()))
|
||||
x.grad = Δ
|
||||
end
|
||||
return
|
||||
end
|
||||
|
||||
back(::Nothing, Δ, once) = return
|
||||
_back(::Nothing, Δ) = return
|
||||
|
||||
function _back(c::Call, Δ)
|
||||
Δs = c.func(Δ)
|
||||
(Δs isa Tuple && length(Δs) >= length(c.args)) ||
|
||||
error("Gradient is not a tuple of length $(length(c.args))")
|
||||
foreach((x, d) -> _back(x, d), c.args, data.(Δs))
|
||||
end
|
||||
|
||||
_back(::Call{Nothing}, Δ) = nothing
|
||||
_back(::Call{Missing}, Δ) = error("`back!` was already used")
|
||||
|
||||
function back(x::Tracked, Δ, once)
|
||||
_back(x, Δ)
|
||||
walk(x, once = once) do x
|
||||
_back(x.f, x.grad)
|
||||
end
|
||||
end
|
||||
|
||||
# Interface methods
|
||||
|
||||
# TODO: if an error occurs in `back` the refcounts will be broken
|
||||
# and `back` will silently fail to update.
|
||||
# (but only if you re-use intermediate values between passes)
|
||||
# Refcounts are also probably not safe in some situations (e.g. back called
|
||||
# from within a backpropagator)
|
||||
|
||||
function back!(x, Δ; once = true)
|
||||
istracked(x) || return
|
||||
scan(x)
|
||||
back(tracker(x), Δ, once)
|
||||
return
|
||||
end
|
||||
|
@ -124,35 +122,35 @@ end
|
|||
|
||||
accum!(g::Grads, x, Δ) = g[x] = haskey(g, x) ? g[x] .+ Δ : Δ
|
||||
|
||||
function back_(g::Grads, c::Call, Δ)
|
||||
function _back(g::Grads, c::Call, Δ)
|
||||
Δs = c.func(Δ)
|
||||
(Δs isa Tuple && length(Δs) >= length(c.args)) ||
|
||||
error("Gradient is not a tuple of length $(length(c.args))")
|
||||
foreach((x, Δ) -> back(g, x, Δ), c.args, Δs)
|
||||
foreach((x, Δ) -> _back(g, x, Δ), c.args, Δs)
|
||||
end
|
||||
|
||||
back_(g::Grads, ::Call{Nothing}, Δ) = nothing
|
||||
_back(g::Grads, ::Call{Nothing}, Δ) = nothing
|
||||
|
||||
function back(g::Grads, x::Tracked, Δ)
|
||||
function _back(g::Grads, x::Tracked, Δ)
|
||||
x.isleaf && (accum!(g, x, Δ); return)
|
||||
ref = x.ref -= 1
|
||||
if ref > 0 || haskey(g, x)
|
||||
accum!(g, x, Δ)
|
||||
ref == 0 && back_(g, x.f, g[x])
|
||||
else
|
||||
ref == 0 && back_(g, x.f, Δ)
|
||||
end
|
||||
accum!(g, x, Δ)
|
||||
return
|
||||
end
|
||||
|
||||
back(::Grads, ::Nothing, _) = return
|
||||
_back(g::Grads, ::Nothing, Δ) = return
|
||||
|
||||
function back(g::Grads, x::Tracked, Δ)
|
||||
_back(g, x, Δ)
|
||||
walk(x, once = false) do x
|
||||
_back(g, x.f, g[x])
|
||||
end
|
||||
end
|
||||
|
||||
function forward(f, ps::Params)
|
||||
y = f()
|
||||
y, function (Δ)
|
||||
g = Grads(ps)
|
||||
if istracked(y)
|
||||
scan(y)
|
||||
back(g, tracker(y), Δ)
|
||||
end
|
||||
return g
|
||||
|
|
Loading…
Reference in New Issue