in place implicit gradients

This commit is contained in:
Mike Innes 2019-02-28 14:08:01 +00:00
parent 8b4bc7cc52
commit cd091ad005

View File

@ -68,13 +68,30 @@ function back!(x, Δ; once = true)
return return
end end
function extract_grad!(x)
= copy(grad(x))
= nobacksies("Use `gradient(...; nest = true)` for nested derivatives", )
tracker(x).grad = zero_grad!(grad(x))
return
end
function gradient_(f, xs...) function gradient_(f, xs...)
xs = param.(data.(xs)) xs = param.(data.(xs))
l = f(xs...) l = f(xs...)
losscheck(l) losscheck(l)
back!(l) back!(l)
nobacksies("Use `gradient(...; nest = true)` for nested derivatives", extract_grad!.(xs)
grad.(xs)) end
function gradient_(f, xs::Params)
l = f()
losscheck(l)
back!(l)
gs = Grads()
for x in xs
gs[tracker(x)] = extract_grad!(x)
end
return gs
end end
# Out-of-place gradients # Out-of-place gradients
@ -137,8 +154,6 @@ end
gradient(f, xs...; nest = false) = gradient(f, xs...; nest = false) =
nest ? gradient_nested(f, xs...) : gradient_(f, xs...) nest ? gradient_nested(f, xs...) : gradient_(f, xs...)
gradient(f, ps::Params) = gradient_nested(f, ps)
# Jacobians and Hessians # Jacobians and Hessians
import ..Flux import ..Flux