in place implicit gradients

This commit is contained in:
Mike Innes 2019-02-28 14:08:01 +00:00
parent 8b4bc7cc52
commit cd091ad005

View File

@ -68,13 +68,30 @@ function back!(x, Δ; once = true)
return
end
function extract_grad!(x)
= copy(grad(x))
= nobacksies("Use `gradient(...; nest = true)` for nested derivatives", )
tracker(x).grad = zero_grad!(grad(x))
return
end
function gradient_(f, xs...)
xs = param.(data.(xs))
l = f(xs...)
losscheck(l)
back!(l)
nobacksies("Use `gradient(...; nest = true)` for nested derivatives",
grad.(xs))
extract_grad!.(xs)
end
function gradient_(f, xs::Params)
l = f()
losscheck(l)
back!(l)
gs = Grads()
for x in xs
gs[tracker(x)] = extract_grad!(x)
end
return gs
end
# Out-of-place gradients
@ -137,8 +154,6 @@ end
gradient(f, xs...; nest = false) =
nest ? gradient_nested(f, xs...) : gradient_(f, xs...)
gradient(f, ps::Params) = gradient_nested(f, ps)
# Jacobians and Hessians
import ..Flux