Merge pull request #465 from FluxML/mji/once

Destroy AD graph when doing in-place gradients
This commit is contained in:
Mike J Innes 2018-10-31 14:14:38 +00:00 committed by GitHub
commit 70283e1971
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 27 additions and 24 deletions

View File

@ -19,47 +19,50 @@ function scan(x)
return return
end end
function back_(c::Call, Δ) function back_(c::Call, Δ, once)
Δs = c.func(Δ) Δs = c.func(Δ)
(Δs isa Tuple && length(Δs) >= length(c.args)) || (Δs isa Tuple && length(Δs) >= length(c.args)) ||
error("Gradient is not a tuple of length $(length(c.args))") error("Gradient is not a tuple of length $(length(c.args))")
foreach(back, c.args, data.(Δs)) foreach((x, d) -> back(x, d, once), c.args, data.(Δs))
end end
back_(::Call{Nothing}, Δ) = nothing back_(::Call{Nothing}, Δ, once) = nothing
back_(::Call{Missing}, Δ, once) = error("`back!` was already used")
accum!(x, Δ) = x .+ Δ accum!(x, Δ) = x .+ Δ
accum!(x::AbstractArray, Δ) = (x .+= Δ) accum!(x::AbstractArray, Δ) = (x .+= Δ)
function back(x::Tracked, Δ) function back(x::Tracked, Δ, once)
x.isleaf && (x.grad = accum!(x.grad, Δ); return) x.isleaf && (x.grad = accum!(x.grad, Δ); return)
ref = x.ref -= 1 ref = x.ref -= 1
if ref > 0 || isdefined(x, :grad) grad = if isdefined(x, :grad)
if isdefined(x, :grad)
x.grad = accum!(x.grad, Δ) x.grad = accum!(x.grad, Δ)
else elseif ref > 0
x.grad = Δ x.grad = Δ
end
ref == 0 && back_(x.f, x.grad)
else else
ref == 0 && back_(x.f, Δ) Δ
end
if ref == 0
back_(x.f, grad, once)
once && !x.isleaf && (x.f = Call(missing, ()))
end end
return return
end end
back(::Nothing, _) = return back(::Nothing, Δ, once) = return
# Interface methods # Interface methods
# TODO: if an error occurs in `back` the refcounts will be broken # TODO: if an error occurs in `back` the refcounts will be broken
# and `back` will silently fail to update. # and `back` will silently fail to update.
# (but only if you re-use intermediate values between passes)
# Refcounts are also probably not safe in some situations (e.g. back called # Refcounts are also probably not safe in some situations (e.g. back called
# from within a backpropagator) # from within a backpropagator)
function back!(x, Δ) function back!(x, Δ; once = true)
istracked(x) || return istracked(x) || return
scan(x) scan(x)
back(tracker(x), Δ) back(tracker(x), Δ, once)
return return
end end
@ -91,12 +94,12 @@ Grads() = Grads(IdDict())
Grads(ps::Params) = Grads(IdDict(tracker(p) => init_grad(data(p)) for p in ps)) Grads(ps::Params) = Grads(IdDict(tracker(p) => init_grad(data(p)) for p in ps))
Base.getindex(g::Grads, x::Tracked) = g.grads[x] Base.getindex(g::Grads, x::Tracked) = g.grads[x]
function Base.getindex(g::Grads, x) function Base.getindex(g::Grads, x)
istracked(x) || error("Object not tracked: $x") istracked(x) || error("Object not tracked: $x")
g[tracker(x)] g[tracker(x)]
end end
accum!(g::Grads, x, Δ) = g[x] = haskey(g, x) ? g[x] .+ Δ : Δ accum!(g::Grads, x, Δ) = g[x] = haskey(g, x) ? g[x] .+ Δ : Δ
function back_(g::Grads, c::Call, Δ) function back_(g::Grads, c::Call, Δ)

View File

@ -10,10 +10,10 @@ tracker(x::TrackedReal) = x.tracker
track(f::Call, x::Real) = TrackedReal(x, Tracked{typeof(x)}(f, zero(x))) track(f::Call, x::Real) = TrackedReal(x, Tracked{typeof(x)}(f, zero(x)))
function back!(x::TrackedReal) function back!(x::TrackedReal; once = true)
isinf(x) && error("Loss is Inf") isinf(x) && error("Loss is Inf")
isnan(x) && error("Loss is NaN") isnan(x) && error("Loss is NaN")
return back!(x, 1) return back!(x, 1, once = once)
end end
function Base.show(io::IO, x::TrackedReal) function Base.show(io::IO, x::TrackedReal)
@ -123,8 +123,8 @@ function scan(c::Call{typeof(collect)})
foreach(scan, c.args[1]) foreach(scan, c.args[1])
end end
function back_(c::Call{typeof(collect)}, Δ) function back_(c::Call{typeof(collect)}, Δ, once)
foreach(back, c.args[1], data(Δ)) foreach((x, d) -> back(x, d, once), c.args[1], data(Δ))
end end
function back_(g::Grads, c::Call{typeof(collect)}, Δ) function back_(g::Grads, c::Call{typeof(collect)}, Δ)

View File

@ -147,9 +147,9 @@ function jacobian(m,x)
n = length(x) n = length(x)
J = Matrix{eltype(x)}(undef,n,k) J = Matrix{eltype(x)}(undef,n,k)
for i = 1:k for i = 1:k
Flux.back!(y[i]) # Populate gradient accumulator Flux.back!(y[i], once = false) # Populate gradient accumulator
J[:,i] = xp.grad J[:,i] = xp.grad
xp.grad .*= 0 # Reset gradient accumulator xp.grad .= 0 # Reset gradient accumulator
end end
J' J'
end end

View File

@ -237,10 +237,10 @@ end
@testset "Intermediates" begin @testset "Intermediates" begin
x = param([1]) x = param([1])
l = sum((x .+ x).^2) l = sum((x .+ x).^2)
Flux.back!(l) Flux.back!(l, once = false)
@test x.grad == [8] @test x.grad == [8]
x.grad .= 0 x.grad .= 0
Flux.back!(l) Flux.back!(l, once = false)
@test x.grad == [8] @test x.grad == [8]
end end