diff --git a/src/tracker/back.jl b/src/tracker/back.jl index 0dda0082..03fe14bb 100644 --- a/src/tracker/back.jl +++ b/src/tracker/back.jl @@ -68,13 +68,30 @@ function back!(x, Δ; once = true) return end +function extract_grad!(x) + x̄ = copy(grad(x)) + x̄ = nobacksies("Use `gradient(...; nest = true)` for nested derivatives", x̄) + tracker(x).grad = zero_grad!(grad(x)) + return x̄ +end + function gradient_(f, xs...) xs = param.(data.(xs)) l = f(xs...) losscheck(l) back!(l) - nobacksies("Use `gradient(...; nest = true)` for nested derivatives", - grad.(xs)) + extract_grad!.(xs) +end + +function gradient_(f, xs::Params) + l = f() + losscheck(l) + back!(l) + gs = Grads() + for x in xs + gs[tracker(x)] = extract_grad!(x) + end + return gs end # Out-of-place gradients @@ -137,8 +154,6 @@ end gradient(f, xs...; nest = false) = nest ? gradient_nested(f, xs...) : gradient_(f, xs...) -gradient(f, ps::Params) = gradient_nested(f, ps) - # Jacobians and Hessians import ..Flux