in place implicit gradients
This commit is contained in:
parent
8b4bc7cc52
commit
cd091ad005
@ -68,13 +68,30 @@ function back!(x, Δ; once = true)
|
|||||||
return
|
return
|
||||||
end
|
end
|
||||||
|
|
||||||
|
function extract_grad!(x)
|
||||||
|
x̄ = copy(grad(x))
|
||||||
|
x̄ = nobacksies("Use `gradient(...; nest = true)` for nested derivatives", x̄)
|
||||||
|
tracker(x).grad = zero_grad!(grad(x))
|
||||||
|
return x̄
|
||||||
|
end
|
||||||
|
|
||||||
function gradient_(f, xs...)
|
function gradient_(f, xs...)
|
||||||
xs = param.(data.(xs))
|
xs = param.(data.(xs))
|
||||||
l = f(xs...)
|
l = f(xs...)
|
||||||
losscheck(l)
|
losscheck(l)
|
||||||
back!(l)
|
back!(l)
|
||||||
nobacksies("Use `gradient(...; nest = true)` for nested derivatives",
|
extract_grad!.(xs)
|
||||||
grad.(xs))
|
end
|
||||||
|
|
||||||
|
function gradient_(f, xs::Params)
|
||||||
|
l = f()
|
||||||
|
losscheck(l)
|
||||||
|
back!(l)
|
||||||
|
gs = Grads()
|
||||||
|
for x in xs
|
||||||
|
gs[tracker(x)] = extract_grad!(x)
|
||||||
|
end
|
||||||
|
return gs
|
||||||
end
|
end
|
||||||
|
|
||||||
# Out-of-place gradients
|
# Out-of-place gradients
|
||||||
@ -137,8 +154,6 @@ end
|
|||||||
gradient(f, xs...; nest = false) =
|
gradient(f, xs...; nest = false) =
|
||||||
nest ? gradient_nested(f, xs...) : gradient_(f, xs...)
|
nest ? gradient_nested(f, xs...) : gradient_(f, xs...)
|
||||||
|
|
||||||
gradient(f, ps::Params) = gradient_nested(f, ps)
|
|
||||||
|
|
||||||
# Jacobians and Hessians
|
# Jacobians and Hessians
|
||||||
|
|
||||||
import ..Flux
|
import ..Flux
|
||||||
|
Loading…
Reference in New Issue
Block a user