2018-02-07 22:20:44 +00:00
|
|
|
# Initial gradient for a value `x`: a zero of the same type/shape.
function init_grad(x)
  return zero(x)
end
|
2018-02-12 12:31:15 +00:00
|
|
|
# Fallback "in-place" zeroing: immutable values get a freshly constructed zero.
function zero_grad!(x)
  return zero(x)
end
|
|
|
|
# Zero a gradient array in place and return it.
zero_grad!(x::AbstractArray) = fill!(x, 0)
|
2018-02-07 22:20:44 +00:00
|
|
|
|
2017-09-07 03:09:32 +00:00
|
|
|
# Scan every argument of a call node.
function scan(c::Call)
  for arg in c.args
    scan(arg)
  end
end
|
|
|
|
|
2018-02-07 17:43:25 +00:00
|
|
|
# Walk the graph below `x`, incrementing reference counts. On the first visit
# to a node, recurse into its parent call and reset any stale gradient buffer.
function scan(x::Tracked)
  if x.isleaf
    return
  end
  ref = x.ref += 1
  if ref == 1
    scan(x.f)
    if isdefined(x, :grad)
      x.grad = zero_grad!(x.grad)
    end
  end
  return
end
|
|
|
|
|
2018-02-07 17:43:25 +00:00
|
|
|
# Generic entry point: only tracked values participate in the graph walk.
function scan(x)
  if istracked(x)
    scan(tracker(x))
  end
  return
end
|
|
|
|
|
2018-10-26 15:57:19 +00:00
|
|
|
# Run the stored backpropagator of `c` on Δ and push each resulting gradient
# (as plain data) back into the corresponding argument of the call.
function back_(c::Call, Δ, once)
  Δs = c.func(Δ)
  if !(Δs isa Tuple && length(Δs) >= length(c.args))
    error("Gradient is not a tuple of length $(length(c.args))")
  end
  for (arg, d) in zip(c.args, data.(Δs))
    back(arg, d, once)
  end
end
|
|
|
|
|
2018-10-26 15:57:19 +00:00
|
|
|
# A call with no backpropagator: nothing to do.
function back_(::Call{Nothing}, _, _)
  return nothing
end
|
|
|
|
# A call that was consumed by a previous single-use backward pass.
function back_(::Call{Missing}, _, _)
  error("`back!` was already used")
end
|
2017-09-07 03:09:32 +00:00
|
|
|
|
2018-02-07 22:20:44 +00:00
|
|
|
# Out-of-place accumulation fallback: broadcasted sum of gradient and update.
accum!(x, Δ) = broadcast(+, x, Δ)
|
|
|
|
# In-place accumulation for arrays; returns the mutated array.
function accum!(x::AbstractArray, Δ)
  x .= x .+ Δ
end
|
2018-02-07 20:39:36 +00:00
|
|
|
|
2018-10-26 15:57:19 +00:00
|
|
|
# Backpropagate Δ into the node `x`, decrementing its reference count.
# The parent call's backpropagator fires only once all downstream
# contributions have arrived (i.e. when the refcount set up by `scan` hits 0).
function back(x::Tracked, Δ, once)
  # Leaves just accumulate their gradient; nothing further to propagate.
  x.isleaf && (x.grad = accum!(x.grad, Δ); return)
  ref = x.ref -= 1
  grad = if isdefined(x, :grad)
    # A gradient buffer already exists — fold Δ into it.
    x.grad = accum!(x.grad, Δ)
  elseif ref > 0
    # More contributions are still pending; stash Δ so later calls accumulate.
    x.grad = Δ
  else
    # Single contribution and no buffer: pass Δ straight through.
    Δ
  end
  if ref == 0
    # All contributions received — propagate through the parent call.
    back_(x.f, grad, once)
    # In single-use mode, replace the call so a second `back!` errors loudly
    # (dispatches to the `Call{Missing}` method).
    once && !x.isleaf && (x.f = Call(missing, ()))
  end
  return
end
|
2017-09-07 01:21:35 +00:00
|
|
|
|
2018-10-26 15:57:19 +00:00
|
|
|
# Untracked inputs absorb their gradient silently.
function back(::Nothing, _, _)
  return nothing
end
|
2018-02-07 17:43:25 +00:00
|
|
|
|
2017-09-07 03:09:32 +00:00
|
|
|
# Interface methods
|
|
|
|
|
2017-12-15 16:17:45 +00:00
|
|
|
# TODO: if an error occurs in `back` the refcounts will be broken
|
|
|
|
# and `back` will silently fail to update.
|
2018-10-26 15:57:19 +00:00
|
|
|
# (but only if you re-use intermediate values between passes)
|
2018-07-09 15:57:44 +00:00
|
|
|
# Refcounts are also probably not safe in some situations (e.g. back called
|
|
|
|
# from within a backpropagator)
|
2017-12-15 16:17:45 +00:00
|
|
|
|
2018-10-26 15:57:19 +00:00
|
|
|
# In-place backward pass: recount references below `x`, then seed its tracker
# with Δ. With `once = true` the graph is consumed and cannot be reused.
function back!(x, Δ; once = true)
  if istracked(x)
    scan(x)
    back(tracker(x), Δ, once)
  end
  return
end
|
|
|
|
|
2018-07-09 15:57:44 +00:00
|
|
|
# Out-of-place gradients
|
|
|
|
|
|
|
|
# A collection of parameters to differentiate with respect to, stored by
# object identity so duplicate references are collapsed.
struct Params
  params::IdSet
  # Wrap any iterable of parameters in an `IdSet`.
  Params(xs) = new(IdSet(xs))
end
|
|
|
|
|
2018-08-11 09:51:07 +00:00
|
|
|
# Delegate iteration and length queries on a `Params` to its underlying set.
@forward Params.params Base.iterate, Base.length
|
2018-07-09 15:57:44 +00:00
|
|
|
|
2018-07-11 14:31:22 +00:00
|
|
|
# Display a `Params` as a comma-separated listing of its parameters.
function Base.show(io::IO, ps::Params)
  print(io, "Params([", join(ps.params, ", "), "])")
end
|
|
|
|
|
2018-07-09 15:57:44 +00:00
|
|
|
# Out-of-place gradient store: maps nodes (keyed by identity) to their
# accumulated gradients.
struct Grads
  grads::IdDict{Any,Any}
end
|
|
|
|
|
2018-07-11 14:31:22 +00:00
|
|
|
# Gradient stores print opaquely — their contents are not useful to display.
function Base.show(io::IO, ps::Grads)
  println(io, "Grads(...)")
end
|
|
|
|
|
2018-07-12 19:56:51 +00:00
|
|
|
# An empty gradient store.
Grads() = Grads(IdDict{Any,Any}())
|
2018-07-09 15:57:44 +00:00
|
|
|
|
2018-08-11 09:51:07 +00:00
|
|
|
# Delegate dict-style operations on a `Grads` to its underlying `IdDict`.
@forward Grads.grads Base.setindex!, Base.haskey, Base.length, Base.iterate
|
|
|
|
|
2018-07-12 19:56:51 +00:00
|
|
|
# Seed a gradient store with a zero gradient for each parameter's tracker.
Grads(ps::Params) = Grads(IdDict(tracker(p) => init_grad(data(p)) for p in ps))
|
2018-07-10 08:03:09 +00:00
|
|
|
|
2018-07-09 15:57:44 +00:00
|
|
|
# Look up the gradient stored for a tracker node.
function Base.getindex(g::Grads, x::Tracked)
  return g.grads[x]
end
|
|
|
|
# Look up the gradient for a tracked value via its tracker; error otherwise.
function Base.getindex(g::Grads, x)
  istracked(x) ? g[tracker(x)] : error("Object not tracked: $x")
end
|
|
|
|
|
|
|
|
|
2018-07-18 13:39:20 +00:00
|
|
|
# Accumulate Δ into the stored gradient for `x`, creating the entry if absent.
function accum!(g::Grads, x, Δ)
  if haskey(g, x)
    g[x] = g[x] .+ Δ
  else
    g[x] = Δ
  end
end
|
2018-07-09 15:57:44 +00:00
|
|
|
|
|
|
|
# Out-of-place variant: run the backpropagator of `c` on Δ and push each
# resulting gradient into the store via the corresponding argument.
function back_(g::Grads, c::Call, Δ)
  Δs = c.func(Δ)
  if !(Δs isa Tuple && length(Δs) >= length(c.args))
    error("Gradient is not a tuple of length $(length(c.args))")
  end
  for (arg, d) in zip(c.args, Δs)
    back(g, arg, d)
  end
end
|
|
|
|
|
2018-06-12 17:09:18 +00:00
|
|
|
# A call with no backpropagator: nothing to do.
function back_(g::Grads, ::Call{Nothing}, Δ)
  return nothing
end
|
2018-07-09 15:57:44 +00:00
|
|
|
|
|
|
|
# Out-of-place backward step for node `x`: gradients accumulate into the
# store `g` instead of mutating the node. The parent call fires only when the
# refcount set up by `scan` reaches zero.
function back(g::Grads, x::Tracked, Δ)
  # Leaves just accumulate into the store; nothing further to propagate.
  x.isleaf && (accum!(g, x, Δ); return)
  ref = x.ref -= 1
  if ref > 0 || haskey(g, x)
    # Either more contributions are pending or an entry already exists:
    # accumulate, and propagate the accumulated total once complete.
    accum!(g, x, Δ)
    ref == 0 && back_(g, x.f, g[x])
  else
    # Single contribution with no stored entry: pass Δ straight through.
    ref == 0 && back_(g, x.f, Δ)
  end
  return
end
|
|
|
|
|
2018-06-12 17:09:18 +00:00
|
|
|
# Untracked inputs absorb their gradient silently.
function back(::Grads, ::Nothing, _)
  return nothing
end
|
2018-07-09 15:57:44 +00:00
|
|
|
|
|
|
|
# Run `f()` forward and return `(y, pullback)`, where `pullback(Δ)` produces a
# fresh `Grads` for the parameters in `ps` seeded by the output sensitivity Δ.
function forward(f, ps::Params)
  y = f()
  pullback = Δ -> begin
    g = Grads(ps)
    # Untracked outputs yield only the pre-seeded parameter gradients.
    istracked(y) || return g
    scan(y)
    back(g, tracker(y), Δ)
    return g
  end
  return y, pullback
end
|
|
|
|
|
|
|
|
# Convenience overload: track the positional arguments themselves and return a
# pullback that yields one gradient per argument (in order).
function forward(f, args...)
  tracked = param.(args)
  y, back = forward(() -> f(tracked...), Params(tracked))
  return y, Δ -> getindex.(Ref(back(Δ)), tracked)
end
|
|
|
|
|
|
|
|
# Validate a loss value before differentiation: it must be a finite real scalar.
function losscheck(x)
  if !(x isa Real)
    error("Function output is not scalar")
  end
  if isinf(x)
    error("Loss is infinite")
  end
  isnan(x) && error("Loss is NaN")
end
|
|
|
|
|
|
|
|
# Gradient of `f` at `args`: forward pass, sanity-check the scalar loss,
# then pull back a unit sensitivity.
function gradient(f, args...)
  out, pull = forward(f, args...)
  losscheck(out)
  return pull(1)
end
|
|
|
|
|
|
|
|
# Scalar derivative: the gradient with respect to the single argument.
function derivative(f, x)
  return first(gradient(f, x))
end
|
2018-07-30 19:08:44 +00:00
|
|
|
|
|
|
|
# Non-nesting versions
|
|
|
|
|
|
|
|
# Non-nesting gradient: track the inputs, run an in-place backward pass,
# and read the gradients back off the tracked values.
function gradient_(f, xs...)
  tracked = param.(xs)
  loss = f(tracked...)
  losscheck(loss)
  back!(loss)
  grad.(tracked)
end
|