destroy AD graph when doing in-place gradients

This commit is contained in:
Mike J Innes 2018-10-26 16:57:19 +01:00
parent 44ccdb7ca9
commit c21d768b7c
4 changed files with 26 additions and 23 deletions

View File

@ -19,47 +19,50 @@ function scan(x)
return return
end end
function back_(c::Call, Δ) function back_(c::Call, Δ, once)
Δs = c.func(Δ) Δs = c.func(Δ)
(Δs isa Tuple && length(Δs) >= length(c.args)) || (Δs isa Tuple && length(Δs) >= length(c.args)) ||
error("Gradient is not a tuple of length $(length(c.args))") error("Gradient is not a tuple of length $(length(c.args))")
foreach(back, c.args, data.(Δs)) foreach((x, d) -> back(x, d, once), c.args, data.(Δs))
end end
back_(::Call{Nothing}, Δ) = nothing back_(::Call{Nothing}, _, _) = nothing
back_(::Call{Missing}, _, _) = error("`back!` was already used")
accum!(x, Δ) = x .+ Δ accum!(x, Δ) = x .+ Δ
accum!(x::AbstractArray, Δ) = (x .+= Δ) accum!(x::AbstractArray, Δ) = (x .+= Δ)
function back(x::Tracked, Δ) function back(x::Tracked, Δ, once)
x.isleaf && (x.grad = accum!(x.grad, Δ); return) x.isleaf && (x.grad = accum!(x.grad, Δ); return)
ref = x.ref -= 1 ref = x.ref -= 1
if ref > 0 || isdefined(x, :grad) grad = if isdefined(x, :grad)
if isdefined(x, :grad) x.grad = accum!(x.grad, Δ)
x.grad = accum!(x.grad, Δ) elseif ref > 0
else x.grad = Δ
x.grad = Δ
end
ref == 0 && back_(x.f, x.grad)
else else
ref == 0 && back_(x.f, Δ) Δ
end
if ref == 0
back_(x.f, grad, once)
once && !x.isleaf && (x.f = Call(missing, ()))
end end
return return
end end
back(::Nothing, _) = return back(::Nothing, _, _) = return
# Interface methods # Interface methods
# TODO: if an error occurs in `back` the refcounts will be broken # TODO: if an error occurs in `back` the refcounts will be broken
# and `back` will silently fail to update. # and `back` will silently fail to update.
# (but only if you re-use intermediate values between passes)
# Refcounts are also probably not safe in some situations (e.g. back called # Refcounts are also probably not safe in some situations (e.g. back called
# from within a backpropagator) # from within a backpropagator)
function back!(x, Δ) function back!(x, Δ; once = true)
istracked(x) || return istracked(x) || return
scan(x) scan(x)
back(tracker(x), Δ) back(tracker(x), Δ, once)
return return
end end

View File

@ -10,10 +10,10 @@ tracker(x::TrackedReal) = x.tracker
track(f::Call, x::Real) = TrackedReal(x, Tracked{typeof(x)}(f, zero(x))) track(f::Call, x::Real) = TrackedReal(x, Tracked{typeof(x)}(f, zero(x)))
function back!(x::TrackedReal) function back!(x::TrackedReal; once = true)
isinf(x) && error("Loss is Inf") isinf(x) && error("Loss is Inf")
isnan(x) && error("Loss is NaN") isnan(x) && error("Loss is NaN")
return back!(x, 1) return back!(x, 1, once = once)
end end
function Base.show(io::IO, x::TrackedReal) function Base.show(io::IO, x::TrackedReal)
@ -123,8 +123,8 @@ function scan(c::Call{typeof(collect)})
foreach(scan, c.args[1]) foreach(scan, c.args[1])
end end
function back_(c::Call{typeof(collect)}, Δ) function back_(c::Call{typeof(collect)}, Δ, once)
foreach(back, c.args[1], data(Δ)) foreach((x, d) -> back(x, d, once), c.args[1], data(Δ))
end end
function back_(g::Grads, c::Call{typeof(collect)}, Δ) function back_(g::Grads, c::Call{typeof(collect)}, Δ)

View File

@ -147,9 +147,9 @@ function jacobian(m,x)
n = length(x) n = length(x)
J = Matrix{eltype(x)}(undef,n,k) J = Matrix{eltype(x)}(undef,n,k)
for i = 1:k for i = 1:k
Flux.back!(y[i]) # Populate gradient accumulator Flux.back!(y[i], once = false) # Populate gradient accumulator
J[:,i] = xp.grad J[:,i] = xp.grad
xp.grad .*= 0 # Reset gradient accumulator xp.grad .= 0 # Reset gradient accumulator
end end
J' J'
end end

View File

@ -232,10 +232,10 @@ end
@testset "Intermediates" begin @testset "Intermediates" begin
x = param([1]) x = param([1])
l = sum((x .+ x).^2) l = sum((x .+ x).^2)
Flux.back!(l) Flux.back!(l, once = false)
@test x.grad == [8] @test x.grad == [8]
x.grad .= 0 x.grad .= 0
Flux.back!(l) Flux.back!(l, once = false)
@test x.grad == [8] @test x.grad == [8]
end end