From cd091ad005c2a1e6563f3689ed6637bd7c4ac152 Mon Sep 17 00:00:00 2001 From: Mike Innes Date: Thu, 28 Feb 2019 14:08:01 +0000 Subject: [PATCH] in place implicit gradients --- src/tracker/back.jl | 23 +++++++++++++++++++---- 1 file changed, 19 insertions(+), 4 deletions(-) diff --git a/src/tracker/back.jl b/src/tracker/back.jl index 0dda0082..03fe14bb 100644 --- a/src/tracker/back.jl +++ b/src/tracker/back.jl @@ -68,13 +68,30 @@ function back!(x, Δ; once = true) return end +function extract_grad!(x) + x̄ = copy(grad(x)) + x̄ = nobacksies("Use `gradient(...; nest = true)` for nested derivatives", x̄) + tracker(x).grad = zero_grad!(grad(x)) + return x̄ +end + function gradient_(f, xs...) xs = param.(data.(xs)) l = f(xs...) losscheck(l) back!(l) - nobacksies("Use `gradient(...; nest = true)` for nested derivatives", - grad.(xs)) + extract_grad!.(xs) +end + +function gradient_(f, xs::Params) + l = f() + losscheck(l) + back!(l) + gs = Grads() + for x in xs + gs[tracker(x)] = extract_grad!(x) + end + return gs end # Out-of-place gradients @@ -137,8 +154,6 @@ end gradient(f, xs...; nest = false) = nest ? gradient_nested(f, xs...) : gradient_(f, xs...) -gradient(f, ps::Params) = gradient_nested(f, ps) - # Jacobians and Hessians import ..Flux