in place implicit gradients
This commit is contained in:
parent
8b4bc7cc52
commit
cd091ad005
@ -68,13 +68,30 @@ function back!(x, Δ; once = true)
|
||||
return
|
||||
end
|
||||
|
||||
function extract_grad!(x)
|
||||
x̄ = copy(grad(x))
|
||||
x̄ = nobacksies("Use `gradient(...; nest = true)` for nested derivatives", x̄)
|
||||
tracker(x).grad = zero_grad!(grad(x))
|
||||
return x̄
|
||||
end
|
||||
|
||||
function gradient_(f, xs...)
|
||||
xs = param.(data.(xs))
|
||||
l = f(xs...)
|
||||
losscheck(l)
|
||||
back!(l)
|
||||
nobacksies("Use `gradient(...; nest = true)` for nested derivatives",
|
||||
grad.(xs))
|
||||
extract_grad!.(xs)
|
||||
end
|
||||
|
||||
function gradient_(f, xs::Params)
|
||||
l = f()
|
||||
losscheck(l)
|
||||
back!(l)
|
||||
gs = Grads()
|
||||
for x in xs
|
||||
gs[tracker(x)] = extract_grad!(x)
|
||||
end
|
||||
return gs
|
||||
end
|
||||
|
||||
# Out-of-place gradients
|
||||
@ -137,8 +154,6 @@ end
|
||||
gradient(f, xs...; nest = false) =
|
||||
nest ? gradient_nested(f, xs...) : gradient_(f, xs...)
|
||||
|
||||
gradient(f, ps::Params) = gradient_nested(f, ps)
|
||||
|
||||
# Jacobians and Hessians
|
||||
|
||||
import ..Flux
|
||||
|
Loading…
Reference in New Issue
Block a user