# Flux.jl/src/tracker/back.jl

# Gradient initialisation and zeroing (in place for arrays).
init_grad(x) = zero(x)
zero_grad!(x) = zero(x)
zero_grad!(x::AbstractArray) = (x .= 0)

scan(c::Call) = foreach(scan, c.args)

# Count how many times each tracked node is used (`x.ref`) and clear any stale
# gradients before a backwards pass.
function scan(x::Tracked)
  x.isleaf && return
  ref = x.ref += 1
  if ref == 1
    scan(x.f)
    isdefined(x, :grad) && (x.grad = zero_grad!(x.grad))
  end
  return
end

function scan(x)
  istracked(x) && scan(tracker(x))
  return
end

# Run the call's stored backpropagator on `Δ` and push the resulting gradients
# on to each of the call's arguments.
function back_(c::Call, Δ, once)
  Δs = c.func(Δ)
  (Δs isa Tuple && length(Δs) >= length(c.args)) ||
    error("Gradient is not a tuple of length $(length(c.args))")
  foreach((x, d) -> back(x, d, once), c.args, data.(Δs))
end

back_(::Call{Nothing}, Δ, once) = nothing
back_(::Call{Missing}, Δ, once) = error("`back!` was already used")

# Accumulate Δ into an existing gradient, in place when the gradient is an array.
accum!(x, Δ) = x .+ Δ
accum!(x::AbstractArray, Δ) = (x .+= Δ)

# Backpropagate `Δ` through a tracked node: accumulate until the node's last
# use (its refcount reaches zero), then pass the total on to the node's call.
function back(x::Tracked, Δ, once)
  x.isleaf && (x.grad = accum!(x.grad, Δ); return)
  ref = x.ref -= 1
  grad = if isdefined(x, :grad)
    x.grad = accum!(x.grad, Δ)
  elseif ref > 0
    x.grad = Δ
  else
    Δ
  end
  if ref == 0
    back_(x.f, grad, once)
    once && !x.isleaf && (x.f = Call(missing, ()))
  end
  return
end

back(::Nothing, Δ, once) = return

# Interface methods

# TODO: if an error occurs in `back` the refcounts will be broken
# and `back` will silently fail to update.
# (but only if you re-use intermediate values between passes)
# Refcounts are also probably not safe in some situations (e.g. back called
# from within a backpropagator)

# Backpropagate `Δ` from `x` through the recorded graph, accumulating gradients
# into the leaves (params).
function back!(x, Δ; once = true)
  istracked(x) || return
  scan(x)
  back(tracker(x), Δ, once)
  return
end
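
# Illustrative usage (not part of the original source): accumulating gradients
# into leaf params with the in-place API. `W`, `b` and the input are made up.
#
#     W, b = param(randn(3, 2)), param(zeros(3))
#     l = sum(W * randn(2) .+ b)   # tracked scalar loss
#     back!(l, 1)                  # W.grad and b.grad now hold ∂l/∂W and ∂l/∂b
#     # Calling back!(l, 1) again would error ("`back!` was already used");
#     # pass `once = false` to keep the graph reusable.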

# Non-nested gradient: wrap the inputs as fresh params, run `f`, backpropagate
# once and read the accumulated gradients back off the params.
function gradient_(f, xs...)
  xs = param.(data.(xs))
  l = f(xs...)
  losscheck(l)
  back!(l)
  nobacksies("Use `gradient(...; nest = true)` for nested derivatives",
             grad.(xs))
end

# Out-of-place gradients

# An ordered, identity-deduplicated collection of parameters to differentiate
# with respect to.
struct Params
  order::Vector{Any}
  params::IdSet{Any}
  Params() = new([], IdSet())
end

@forward Params.order Base.iterate, Base.length

function Base.push!(ps::Params, x)
  if !(x in ps.params)
    push!(ps.order, x)
    push!(ps.params, x)
  end
  return ps
end

Base.push!(ps::Params, x...) = (foreach(x -> push!(ps, x), x); ps)

Params(xs) = push!(Params(), xs...)
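
# Illustrative usage (not part of the original source); the arrays are made up:
#
#     W, b = param(randn(3, 2)), param(zeros(3))
#     ps = Params([W, b])
#     length(ps)        # 2
#     push!(ps, W)      # no-op: entries are deduplicated by object identity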

function Base.show(io::IO, ps::Params)
  print(io, "Params([")
  join(io, ps.order, ", ")
  print(io, "])")
end

# Gradients for the out-of-place API, keyed by each parameter's tracker.
struct Grads
  grads::IdDict{Any,Any}
end

Base.show(io::IO, ps::Grads) = println(io, "Grads(...)")

Grads() = Grads(IdDict())

@forward Grads.grads Base.setindex!, Base.haskey, Base.length, Base.iterate

Grads(ps::Params) = Grads(IdDict(tracker(p) => init_grad(data(p)) for p in ps))

Base.getindex(g::Grads, x::Tracked) = g.grads[x]

function Base.getindex(g::Grads, x)
  istracked(x) || error("Object not tracked: $x")
  g[tracker(x)]
end

accum!(g::Grads, x, Δ) = g[x] = haskey(g, x) ? g[x] .+ Δ : Δ

# Same traversal as the in-place `back`, but gradients are stored in `g` rather
# than on the tracked nodes themselves.
function back_(g::Grads, c::Call, Δ)
  Δs = c.func(Δ)
  (Δs isa Tuple && length(Δs) >= length(c.args)) ||
    error("Gradient is not a tuple of length $(length(c.args))")
  foreach((x, Δ) -> back(g, x, Δ), c.args, Δs)
end

back_(g::Grads, ::Call{Nothing}, Δ) = nothing

function back(g::Grads, x::Tracked, Δ)
  x.isleaf && (accum!(g, x, Δ); return)
  ref = x.ref -= 1
  if ref > 0 || haskey(g, x)
    accum!(g, x, Δ)
    ref == 0 && back_(g, x.f, g[x])
  else
    ref == 0 && back_(g, x.f, Δ)
  end
  return
end

back(::Grads, ::Nothing, _) = return

# Out-of-place `forward`: run `f` and return its result together with a
# backpropagator that maps an output sensitivity `Δ` to a fresh `Grads`.
function forward(f, ps::Params)
  y = f()
  y, function (Δ)
    g = Grads(ps)
    if istracked(y)
      scan(y)
      back(g, tracker(y), Δ)
    end
    return g
  end
end

function forward(f, args...)
  args = param.(args)
  y, back = forward(() -> f(args...), Params(args))
  y, Δ -> getindex.(Ref(back(Δ)), args)
end
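
# Illustrative usage (not part of the original source):
#
#     y, back = forward(x -> 3x^2 + 2x + 1, 5)
#     y          # 86 (tracked)
#     back(1)    # (32.0,), since d/dx (3x^2 + 2x + 1) = 6x + 2 at x = 5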

function losscheck(x)
  x isa Real || error("Function output is not scalar")
  isinf(x) && error("Loss is infinite")
  isnan(x) && error("Loss is NaN")
end

function gradient_nested(f, args...)
  y, back = forward(f, args...)
  losscheck(y)
  return back(1)
end

gradient(f, xs...; nest = false) =
  nest ? gradient_nested(f, xs...) : gradient_(f, xs...)

gradient(f, ps::Params) = gradient_nested(f, ps)
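
# Illustrative usage (not part of the original source): the two `gradient`
# entry points, with made-up parameters.
#
#     gradient((a, b) -> sum(a .* b), [1, 2, 3], [4, 5, 6])
#     # -> ([4.0, 5.0, 6.0], [1.0, 2.0, 3.0])
#
#     W, b = param(randn(3, 2)), param(zeros(3))
#     gs = gradient(() -> sum(W * ones(2) .+ b), Params([W, b]))
#     gs[W], gs[b]     # gradients looked up by parameter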

# Jacobians and Hessians

import ..Flux

"""
    J = jacobian(m, x)

Calculate the output Jacobian `J = d/dx m(x)` such that each row `i` of `J`
corresponds to the gradient `J[i,:] = ∇ₓ(m(x)[i])`.
"""
function jacobian(m, x)
  xp = param(x)
  y = m(xp)
  k = length(y)
  n = length(x)
  J = Matrix{eltype(x)}(undef, k, n)
  for i = 1:k
    Flux.back!(y[i], once = false) # populate the gradient accumulator
    J[i, :] = xp.grad
    xp.grad .= 0                   # reset the gradient accumulator
  end
  J
end

hessian(f, x) = jacobian(x -> gradient(f, x, nest = true)[1], x)
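
# Illustrative usage (not part of the original source):
#
#     jacobian(x -> [x[1]^2, x[1] * x[2]], [3.0, 2.0])
#     # -> [6.0 0.0; 2.0 3.0]
#
#     hessian(x -> sum(x .^ 2), [1.0, 2.0])   # should be ≈ [2 0; 0 2]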