Merge pull request #465 from FluxML/mji/once

Destroy AD graph when doing in-place gradients
This commit is contained in:
Mike J Innes 2018-10-31 14:14:38 +00:00 committed by GitHub
commit 70283e1971
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 27 additions and 24 deletions

View File

@ -19,47 +19,50 @@ function scan(x)
return
end
function back_(c::Call, Δ)
function back_(c::Call, Δ, once)
Δs = c.func(Δ)
(Δs isa Tuple && length(Δs) >= length(c.args)) ||
error("Gradient is not a tuple of length $(length(c.args))")
foreach(back, c.args, data.(Δs))
foreach((x, d) -> back(x, d, once), c.args, data.(Δs))
end
back_(::Call{Nothing}, Δ) = nothing
back_(::Call{Nothing}, Δ, once) = nothing
back_(::Call{Missing}, Δ, once) = error("`back!` was already used")
accum!(x, Δ) = x .+ Δ
accum!(x::AbstractArray, Δ) = (x .+= Δ)
function back(x::Tracked, Δ)
function back(x::Tracked, Δ, once)
x.isleaf && (x.grad = accum!(x.grad, Δ); return)
ref = x.ref -= 1
if ref > 0 || isdefined(x, :grad)
if isdefined(x, :grad)
x.grad = accum!(x.grad, Δ)
else
x.grad = Δ
end
ref == 0 && back_(x.f, x.grad)
grad = if isdefined(x, :grad)
x.grad = accum!(x.grad, Δ)
elseif ref > 0
x.grad = Δ
else
ref == 0 && back_(x.f, Δ)
Δ
end
if ref == 0
back_(x.f, grad, once)
once && !x.isleaf && (x.f = Call(missing, ()))
end
return
end
back(::Nothing, _) = return
back(::Nothing, Δ, once) = return
# Interface methods
# TODO: if an error occurs in `back` the refcounts will be broken
# and `back` will silently fail to update.
# (but only if you re-use intermediate values between passes)
# Refcounts are also probably not safe in some situations (e.g. back called
# from within a backpropagator)
function back!(x, Δ)
function back!(x, Δ; once = true)
istracked(x) || return
scan(x)
back(tracker(x), Δ)
back(tracker(x), Δ, once)
return
end
@ -91,12 +94,12 @@ Grads() = Grads(IdDict())
Grads(ps::Params) = Grads(IdDict(tracker(p) => init_grad(data(p)) for p in ps))
Base.getindex(g::Grads, x::Tracked) = g.grads[x]
function Base.getindex(g::Grads, x)
istracked(x) || error("Object not tracked: $x")
g[tracker(x)]
end
accum!(g::Grads, x, Δ) = g[x] = haskey(g, x) ? g[x] .+ Δ : Δ
function back_(g::Grads, c::Call, Δ)

View File

@ -10,10 +10,10 @@ tracker(x::TrackedReal) = x.tracker
track(f::Call, x::Real) = TrackedReal(x, Tracked{typeof(x)}(f, zero(x)))
function back!(x::TrackedReal)
function back!(x::TrackedReal; once = true)
isinf(x) && error("Loss is Inf")
isnan(x) && error("Loss is NaN")
return back!(x, 1)
return back!(x, 1, once = once)
end
function Base.show(io::IO, x::TrackedReal)
@ -123,8 +123,8 @@ function scan(c::Call{typeof(collect)})
foreach(scan, c.args[1])
end
function back_(c::Call{typeof(collect)}, Δ)
foreach(back, c.args[1], data(Δ))
function back_(c::Call{typeof(collect)}, Δ, once)
foreach((x, d) -> back(x, d, once), c.args[1], data(Δ))
end
function back_(g::Grads, c::Call{typeof(collect)}, Δ)

View File

@ -147,9 +147,9 @@ function jacobian(m,x)
n = length(x)
J = Matrix{eltype(x)}(undef,n,k)
for i = 1:k
Flux.back!(y[i]) # Populate gradient accumulator
Flux.back!(y[i], once = false) # Populate gradient accumulator
J[:,i] = xp.grad
xp.grad .*= 0 # Reset gradient accumulator
xp.grad .= 0 # Reset gradient accumulator
end
J'
end

View File

@ -237,10 +237,10 @@ end
@testset "Intermediates" begin
x = param([1])
l = sum((x .+ x).^2)
Flux.back!(l)
Flux.back!(l, once = false)
@test x.grad == [8]
x.grad .= 0
Flux.back!(l)
Flux.back!(l, once = false)
@test x.grad == [8]
end