destroy AD graph when doing in-place gradients

This commit is contained in:
Mike J Innes 2018-10-26 16:57:19 +01:00
parent 44ccdb7ca9
commit c21d768b7c
4 changed files with 26 additions and 23 deletions

View File

@ -19,47 +19,50 @@ function scan(x)
return
end
function back_(c::Call, Δ)
function back_(c::Call, Δ, once)
Δs = c.func(Δ)
(Δs isa Tuple && length(Δs) >= length(c.args)) ||
error("Gradient is not a tuple of length $(length(c.args))")
foreach(back, c.args, data.(Δs))
foreach((x, d) -> back(x, d, once), c.args, data.(Δs))
end
back_(::Call{Nothing}, Δ) = nothing
back_(::Call{Nothing}, _, _) = nothing
back_(::Call{Missing}, _, _) = error("`back!` was already used")
accum!(x, Δ) = x .+ Δ
accum!(x::AbstractArray, Δ) = (x .+= Δ)
function back(x::Tracked, Δ)
function back(x::Tracked, Δ, once)
x.isleaf && (x.grad = accum!(x.grad, Δ); return)
ref = x.ref -= 1
if ref > 0 || isdefined(x, :grad)
if isdefined(x, :grad)
x.grad = accum!(x.grad, Δ)
else
x.grad = Δ
end
ref == 0 && back_(x.f, x.grad)
grad = if isdefined(x, :grad)
x.grad = accum!(x.grad, Δ)
elseif ref > 0
x.grad = Δ
else
ref == 0 && back_(x.f, Δ)
Δ
end
if ref == 0
back_(x.f, grad, once)
once && !x.isleaf && (x.f = Call(missing, ()))
end
return
end
back(::Nothing, _) = return
back(::Nothing, _, _) = return
# Interface methods
# TODO: if an error occurs in `back` the refcounts will be broken
# and `back` will silently fail to update.
# (but only if you re-use intermediate values between passes)
# Refcounts are also probably not safe in some situations (e.g. back called
# from within a backpropagator)
function back!(x, Δ)
function back!(x, Δ; once = true)
istracked(x) || return
scan(x)
back(tracker(x), Δ)
back(tracker(x), Δ, once)
return
end

View File

@ -10,10 +10,10 @@ tracker(x::TrackedReal) = x.tracker
track(f::Call, x::Real) = TrackedReal(x, Tracked{typeof(x)}(f, zero(x)))
function back!(x::TrackedReal)
function back!(x::TrackedReal; once = true)
isinf(x) && error("Loss is Inf")
isnan(x) && error("Loss is NaN")
return back!(x, 1)
return back!(x, 1, once = once)
end
function Base.show(io::IO, x::TrackedReal)
@ -123,8 +123,8 @@ function scan(c::Call{typeof(collect)})
foreach(scan, c.args[1])
end
function back_(c::Call{typeof(collect)}, Δ)
foreach(back, c.args[1], data(Δ))
function back_(c::Call{typeof(collect)}, Δ, once)
foreach((x, d) -> back(x, d, once), c.args[1], data(Δ))
end
function back_(g::Grads, c::Call{typeof(collect)}, Δ)

View File

@ -147,9 +147,9 @@ function jacobian(m,x)
n = length(x)
J = Matrix{eltype(x)}(undef,n,k)
for i = 1:k
Flux.back!(y[i]) # Populate gradient accumulator
Flux.back!(y[i], once = false) # Populate gradient accumulator
J[:,i] = xp.grad
xp.grad .*= 0 # Reset gradient accumulator
xp.grad .= 0 # Reset gradient accumulator
end
J'
end

View File

@ -232,10 +232,10 @@ end
@testset "Intermediates" begin
x = param([1])
l = sum((x .+ x).^2)
Flux.back!(l)
Flux.back!(l, once = false)
@test x.grad == [8]
x.grad .= 0
Flux.back!(l)
Flux.back!(l, once = false)
@test x.grad == [8]
end