destroy AD graph when doing in-place gradients
This commit is contained in:
parent
44ccdb7ca9
commit
c21d768b7c
@ -19,47 +19,50 @@ function scan(x)
|
|||||||
return
|
return
|
||||||
end
|
end
|
||||||
|
|
||||||
function back_(c::Call, Δ)
|
function back_(c::Call, Δ, once)
|
||||||
Δs = c.func(Δ)
|
Δs = c.func(Δ)
|
||||||
(Δs isa Tuple && length(Δs) >= length(c.args)) ||
|
(Δs isa Tuple && length(Δs) >= length(c.args)) ||
|
||||||
error("Gradient is not a tuple of length $(length(c.args))")
|
error("Gradient is not a tuple of length $(length(c.args))")
|
||||||
foreach(back, c.args, data.(Δs))
|
foreach((x, d) -> back(x, d, once), c.args, data.(Δs))
|
||||||
end
|
end
|
||||||
|
|
||||||
back_(::Call{Nothing}, Δ) = nothing
|
back_(::Call{Nothing}, _, _) = nothing
|
||||||
|
back_(::Call{Missing}, _, _) = error("`back!` was already used")
|
||||||
|
|
||||||
accum!(x, Δ) = x .+ Δ
|
accum!(x, Δ) = x .+ Δ
|
||||||
accum!(x::AbstractArray, Δ) = (x .+= Δ)
|
accum!(x::AbstractArray, Δ) = (x .+= Δ)
|
||||||
|
|
||||||
function back(x::Tracked, Δ)
|
function back(x::Tracked, Δ, once)
|
||||||
x.isleaf && (x.grad = accum!(x.grad, Δ); return)
|
x.isleaf && (x.grad = accum!(x.grad, Δ); return)
|
||||||
ref = x.ref -= 1
|
ref = x.ref -= 1
|
||||||
if ref > 0 || isdefined(x, :grad)
|
grad = if isdefined(x, :grad)
|
||||||
if isdefined(x, :grad)
|
|
||||||
x.grad = accum!(x.grad, Δ)
|
x.grad = accum!(x.grad, Δ)
|
||||||
else
|
elseif ref > 0
|
||||||
x.grad = Δ
|
x.grad = Δ
|
||||||
end
|
|
||||||
ref == 0 && back_(x.f, x.grad)
|
|
||||||
else
|
else
|
||||||
ref == 0 && back_(x.f, Δ)
|
Δ
|
||||||
|
end
|
||||||
|
if ref == 0
|
||||||
|
back_(x.f, grad, once)
|
||||||
|
once && !x.isleaf && (x.f = Call(missing, ()))
|
||||||
end
|
end
|
||||||
return
|
return
|
||||||
end
|
end
|
||||||
|
|
||||||
back(::Nothing, _) = return
|
back(::Nothing, _, _) = return
|
||||||
|
|
||||||
# Interface methods
|
# Interface methods
|
||||||
|
|
||||||
# TODO: if an error occurs in `back` the refcounts will be broken
|
# TODO: if an error occurs in `back` the refcounts will be broken
|
||||||
# and `back` will silently fail to update.
|
# and `back` will silently fail to update.
|
||||||
|
# (but only if you re-use intermediate values between passes)
|
||||||
# Refcounts are also probably not safe in some situations (e.g. back called
|
# Refcounts are also probably not safe in some situations (e.g. back called
|
||||||
# from within a backpropagator)
|
# from within a backpropagator)
|
||||||
|
|
||||||
function back!(x, Δ)
|
function back!(x, Δ; once = true)
|
||||||
istracked(x) || return
|
istracked(x) || return
|
||||||
scan(x)
|
scan(x)
|
||||||
back(tracker(x), Δ)
|
back(tracker(x), Δ, once)
|
||||||
return
|
return
|
||||||
end
|
end
|
||||||
|
|
||||||
|
@ -10,10 +10,10 @@ tracker(x::TrackedReal) = x.tracker
|
|||||||
|
|
||||||
track(f::Call, x::Real) = TrackedReal(x, Tracked{typeof(x)}(f, zero(x)))
|
track(f::Call, x::Real) = TrackedReal(x, Tracked{typeof(x)}(f, zero(x)))
|
||||||
|
|
||||||
function back!(x::TrackedReal)
|
function back!(x::TrackedReal; once = true)
|
||||||
isinf(x) && error("Loss is Inf")
|
isinf(x) && error("Loss is Inf")
|
||||||
isnan(x) && error("Loss is NaN")
|
isnan(x) && error("Loss is NaN")
|
||||||
return back!(x, 1)
|
return back!(x, 1, once = once)
|
||||||
end
|
end
|
||||||
|
|
||||||
function Base.show(io::IO, x::TrackedReal)
|
function Base.show(io::IO, x::TrackedReal)
|
||||||
@ -123,8 +123,8 @@ function scan(c::Call{typeof(collect)})
|
|||||||
foreach(scan, c.args[1])
|
foreach(scan, c.args[1])
|
||||||
end
|
end
|
||||||
|
|
||||||
function back_(c::Call{typeof(collect)}, Δ)
|
function back_(c::Call{typeof(collect)}, Δ, once)
|
||||||
foreach(back, c.args[1], data(Δ))
|
foreach((x, d) -> back(x, d, once), c.args[1], data(Δ))
|
||||||
end
|
end
|
||||||
|
|
||||||
function back_(g::Grads, c::Call{typeof(collect)}, Δ)
|
function back_(g::Grads, c::Call{typeof(collect)}, Δ)
|
||||||
|
@ -147,9 +147,9 @@ function jacobian(m,x)
|
|||||||
n = length(x)
|
n = length(x)
|
||||||
J = Matrix{eltype(x)}(undef,n,k)
|
J = Matrix{eltype(x)}(undef,n,k)
|
||||||
for i = 1:k
|
for i = 1:k
|
||||||
Flux.back!(y[i]) # Populate gradient accumulator
|
Flux.back!(y[i], once = false) # Populate gradient accumulator
|
||||||
J[:,i] = xp.grad
|
J[:,i] = xp.grad
|
||||||
xp.grad .*= 0 # Reset gradient accumulator
|
xp.grad .= 0 # Reset gradient accumulator
|
||||||
end
|
end
|
||||||
J'
|
J'
|
||||||
end
|
end
|
||||||
|
@ -232,10 +232,10 @@ end
|
|||||||
@testset "Intermediates" begin
|
@testset "Intermediates" begin
|
||||||
x = param([1])
|
x = param([1])
|
||||||
l = sum((x .+ x).^2)
|
l = sum((x .+ x).^2)
|
||||||
Flux.back!(l)
|
Flux.back!(l, once = false)
|
||||||
@test x.grad == [8]
|
@test x.grad == [8]
|
||||||
x.grad .= 0
|
x.grad .= 0
|
||||||
Flux.back!(l)
|
Flux.back!(l, once = false)
|
||||||
@test x.grad == [8]
|
@test x.grad == [8]
|
||||||
end
|
end
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user