Merge pull request #465 from FluxML/mji/once
Destroy AD graph when doing in-place gradients
This commit is contained in:
commit
70283e1971
|
@ -19,47 +19,50 @@ function scan(x)
|
|||
return
|
||||
end
|
||||
|
||||
function back_(c::Call, Δ)
|
||||
function back_(c::Call, Δ, once)
|
||||
Δs = c.func(Δ)
|
||||
(Δs isa Tuple && length(Δs) >= length(c.args)) ||
|
||||
error("Gradient is not a tuple of length $(length(c.args))")
|
||||
foreach(back, c.args, data.(Δs))
|
||||
foreach((x, d) -> back(x, d, once), c.args, data.(Δs))
|
||||
end
|
||||
|
||||
back_(::Call{Nothing}, Δ) = nothing
|
||||
back_(::Call{Nothing}, Δ, once) = nothing
|
||||
back_(::Call{Missing}, Δ, once) = error("`back!` was already used")
|
||||
|
||||
accum!(x, Δ) = x .+ Δ
|
||||
accum!(x::AbstractArray, Δ) = (x .+= Δ)
|
||||
|
||||
function back(x::Tracked, Δ)
|
||||
function back(x::Tracked, Δ, once)
|
||||
x.isleaf && (x.grad = accum!(x.grad, Δ); return)
|
||||
ref = x.ref -= 1
|
||||
if ref > 0 || isdefined(x, :grad)
|
||||
if isdefined(x, :grad)
|
||||
x.grad = accum!(x.grad, Δ)
|
||||
else
|
||||
x.grad = Δ
|
||||
end
|
||||
ref == 0 && back_(x.f, x.grad)
|
||||
grad = if isdefined(x, :grad)
|
||||
x.grad = accum!(x.grad, Δ)
|
||||
elseif ref > 0
|
||||
x.grad = Δ
|
||||
else
|
||||
ref == 0 && back_(x.f, Δ)
|
||||
Δ
|
||||
end
|
||||
if ref == 0
|
||||
back_(x.f, grad, once)
|
||||
once && !x.isleaf && (x.f = Call(missing, ()))
|
||||
end
|
||||
return
|
||||
end
|
||||
|
||||
back(::Nothing, _) = return
|
||||
back(::Nothing, Δ, once) = return
|
||||
|
||||
# Interface methods
|
||||
|
||||
# TODO: if an error occurs in `back` the refcounts will be broken
|
||||
# and `back` will silently fail to update.
|
||||
# (but only if you re-use intermediate values between passes)
|
||||
# Refcounts are also probably not safe in some situations (e.g. back called
|
||||
# from within a backpropagator)
|
||||
|
||||
function back!(x, Δ)
|
||||
function back!(x, Δ; once = true)
|
||||
istracked(x) || return
|
||||
scan(x)
|
||||
back(tracker(x), Δ)
|
||||
back(tracker(x), Δ, once)
|
||||
return
|
||||
end
|
||||
|
||||
|
@ -91,12 +94,12 @@ Grads() = Grads(IdDict())
|
|||
Grads(ps::Params) = Grads(IdDict(tracker(p) => init_grad(data(p)) for p in ps))
|
||||
|
||||
Base.getindex(g::Grads, x::Tracked) = g.grads[x]
|
||||
|
||||
function Base.getindex(g::Grads, x)
|
||||
istracked(x) || error("Object not tracked: $x")
|
||||
g[tracker(x)]
|
||||
end
|
||||
|
||||
|
||||
accum!(g::Grads, x, Δ) = g[x] = haskey(g, x) ? g[x] .+ Δ : Δ
|
||||
|
||||
function back_(g::Grads, c::Call, Δ)
|
||||
|
|
|
@ -10,10 +10,10 @@ tracker(x::TrackedReal) = x.tracker
|
|||
|
||||
track(f::Call, x::Real) = TrackedReal(x, Tracked{typeof(x)}(f, zero(x)))
|
||||
|
||||
function back!(x::TrackedReal)
|
||||
function back!(x::TrackedReal; once = true)
|
||||
isinf(x) && error("Loss is Inf")
|
||||
isnan(x) && error("Loss is NaN")
|
||||
return back!(x, 1)
|
||||
return back!(x, 1, once = once)
|
||||
end
|
||||
|
||||
function Base.show(io::IO, x::TrackedReal)
|
||||
|
@ -123,8 +123,8 @@ function scan(c::Call{typeof(collect)})
|
|||
foreach(scan, c.args[1])
|
||||
end
|
||||
|
||||
function back_(c::Call{typeof(collect)}, Δ)
|
||||
foreach(back, c.args[1], data(Δ))
|
||||
function back_(c::Call{typeof(collect)}, Δ, once)
|
||||
foreach((x, d) -> back(x, d, once), c.args[1], data(Δ))
|
||||
end
|
||||
|
||||
function back_(g::Grads, c::Call{typeof(collect)}, Δ)
|
||||
|
|
|
@ -147,9 +147,9 @@ function jacobian(m,x)
|
|||
n = length(x)
|
||||
J = Matrix{eltype(x)}(undef,n,k)
|
||||
for i = 1:k
|
||||
Flux.back!(y[i]) # Populate gradient accumulator
|
||||
Flux.back!(y[i], once = false) # Populate gradient accumulator
|
||||
J[:,i] = xp.grad
|
||||
xp.grad .*= 0 # Reset gradient accumulator
|
||||
xp.grad .= 0 # Reset gradient accumulator
|
||||
end
|
||||
J'
|
||||
end
|
||||
|
|
|
@ -237,10 +237,10 @@ end
|
|||
@testset "Intermediates" begin
|
||||
x = param([1])
|
||||
l = sum((x .+ x).^2)
|
||||
Flux.back!(l)
|
||||
Flux.back!(l, once = false)
|
||||
@test x.grad == [8]
|
||||
x.grad .= 0
|
||||
Flux.back!(l)
|
||||
Flux.back!(l, once = false)
|
||||
@test x.grad == [8]
|
||||
end
|
||||
|
||||
|
|
Loading…
Reference in New Issue