iterative backwards pass

This commit is contained in:
Mike J Innes 2018-12-19 18:37:17 +00:00
parent 9781f063aa
commit 66580f7f20
2 changed files with 55 additions and 58 deletions

View File

@ -30,13 +30,12 @@ a::Call == b::Call = a.func == b.func && a.args == b.args
@inline (c::Call)() = c.func(data.(c.args)...)
mutable struct Tracked{T}
ref::UInt32
f::Call
isleaf::Bool
grad::T
Tracked{T}(f::Call) where T = new(0, f, false)
Tracked{T}(f::Call, grad::T) where T = new(0, f, false, grad)
Tracked{T}(f::Call{Nothing}, grad::T) where T = new(0, f, true, grad)
Tracked{T}(f::Call) where T = new(f, false)
Tracked{T}(f::Call, grad::T) where T = new(f, false, grad)
Tracked{T}(f::Call{Nothing}, grad::T) where T = new(f, true, grad)
end
istracked(x::Tracked) = true

View File

@ -2,66 +2,64 @@ init_grad(x) = zero(x)
zero_grad!(x) = zero(x)
zero_grad!(x::AbstractArray) = (x .= 0)
scan(c::Call) = foreach(scan, c.args)
function scan(x::Tracked)
x.isleaf && return
ref = x.ref += 1
if ref == 1
scan(x.f)
isdefined(x, :grad) && (x.grad = zero_grad!(x.grad))
function _walk(queue, seen, c::Call)
foreach(c.args) do x
x === nothing && return
id = objectid(x)
if id seen
push!(seen, id)
pushfirst!(queue, x)
end
return
end
return
end
function scan(x)
istracked(x) && scan(tracker(x))
return
function walk(f, x::Tracked; once = true)
queue = Tracked[x]
seen = Set{UInt64}()
while !isempty(queue)
x = pop!(queue)
f(x)
_walk(queue, seen, x.f)
once && !x.isleaf && (x.f = Call(missing, ()))
end
end
function back_(c::Call, Δ, once)
Δs = c.func(Δ)
(Δs isa Tuple && length(Δs) >= length(c.args)) ||
error("Gradient is not a tuple of length $(length(c.args))")
foreach((x, d) -> back(x, d, once), c.args, data.(Δs))
end
back_(::Call{Nothing}, Δ, once) = nothing
back_(::Call{Missing}, Δ, once) = error("`back!` was already used")
accum!(x, Δ) = x .+ Δ
accum!(x::AbstractArray, Δ) = (x .+= Δ)
function back(x::Tracked, Δ, once)
x.isleaf && (x.grad = accum!(x.grad, Δ); return)
ref = x.ref -= 1
grad = if isdefined(x, :grad)
function _back(x::Tracked, Δ)
if isdefined(x, :grad)
x.grad = accum!(x.grad, Δ)
elseif ref > 0
x.grad = Δ
else
Δ
end
if ref == 0
back_(x.f, grad, once)
once && !x.isleaf && (x.f = Call(missing, ()))
x.grad = Δ
end
return
end
back(::Nothing, Δ, once) = return
_back(::Nothing, Δ) = return
function _back(c::Call, Δ)
Δs = c.func(Δ)
(Δs isa Tuple && length(Δs) >= length(c.args)) ||
error("Gradient is not a tuple of length $(length(c.args))")
foreach((x, d) -> _back(x, d), c.args, data.(Δs))
end
_back(::Call{Nothing}, Δ) = nothing
_back(::Call{Missing}, Δ) = error("`back!` was already used")
function back(x::Tracked, Δ, once)
_back(x, Δ)
walk(x, once = once) do x
_back(x.f, x.grad)
end
end
# Interface methods
# TODO: if an error occurs in `back` the refcounts will be broken
# and `back` will silently fail to update.
# (but only if you re-use intermediate values between passes)
# Refcounts are also probably not safe in some situations (e.g. back called
# from within a backpropagator)
function back!(x, Δ; once = true)
istracked(x) || return
scan(x)
back(tracker(x), Δ, once)
return
end
@ -124,35 +122,35 @@ end
accum!(g::Grads, x, Δ) = g[x] = haskey(g, x) ? g[x] .+ Δ : Δ
function back_(g::Grads, c::Call, Δ)
function _back(g::Grads, c::Call, Δ)
Δs = c.func(Δ)
(Δs isa Tuple && length(Δs) >= length(c.args)) ||
error("Gradient is not a tuple of length $(length(c.args))")
foreach((x, Δ) -> back(g, x, Δ), c.args, Δs)
foreach((x, Δ) -> _back(g, x, Δ), c.args, Δs)
end
back_(g::Grads, ::Call{Nothing}, Δ) = nothing
_back(g::Grads, ::Call{Nothing}, Δ) = nothing
function back(g::Grads, x::Tracked, Δ)
function _back(g::Grads, x::Tracked, Δ)
x.isleaf && (accum!(g, x, Δ); return)
ref = x.ref -= 1
if ref > 0 || haskey(g, x)
accum!(g, x, Δ)
ref == 0 && back_(g, x.f, g[x])
else
ref == 0 && back_(g, x.f, Δ)
end
accum!(g, x, Δ)
return
end
back(::Grads, ::Nothing, _) = return
_back(g::Grads, ::Nothing, Δ) = return
function back(g::Grads, x::Tracked, Δ)
_back(g, x, Δ)
walk(x, once = false) do x
_back(g, x.f, g[x])
end
end
function forward(f, ps::Params)
y = f()
y, function (Δ)
g = Grads(ps)
if istracked(y)
scan(y)
back(g, tracker(y), Δ)
end
return g