From 179a1e8407bf2390c7ba761396c1e56303e89e4e Mon Sep 17 00:00:00 2001 From: Harry Date: Fri, 21 Sep 2018 16:57:54 +0100 Subject: [PATCH 1/2] Correct Custom Gradients docs MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Fixed a type signature that was incorrect. * Also, replaced `data(a)` with `a.data`. Don't know if the syntax has changed (recently). This may also need to be corrected in line 121. MWE: ```julia using Flux using Flux.Tracker using Flux.Tracker: forward, TrackedReal, track, @grad minus(a, b) = a - b minus(a::TrackedReal, b::TrackedReal) = Tracker.track(minus, a, b) @grad function minus(a, b) return minus(a.data, b.data), Δ -> (Δ, -Δ) end a, b = param(2), param(4) c = minus(a, b) # -2.0 (tracked) Tracker.back!(c) Tracker.grad(a) # 1.00 Tracker.grad(b) # -1.00 ``` --- docs/src/internals/tracker.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/src/internals/tracker.md b/docs/src/internals/tracker.md index 3d39451d..895e4b52 100644 --- a/docs/src/internals/tracker.md +++ b/docs/src/internals/tracker.md @@ -102,14 +102,14 @@ Firstly, we must tell the tracker system to stop when it sees a call to `minus`, ```julia using Flux.Tracker: TrackedReal, track, @grad -minus(a::TrackedArray, b::TrackedArray) = Tracker.track(minus, a, b) +minus(a::TrackedReal, b::TrackedReal) = Tracker.track(minus, a, b) ``` `track` takes care of building a new `Tracked` object and recording the operation on the tape. We just need to provide a gradient definition. ```julia @grad function minus(a, b) - return minus(data(a),data(b)), Δ -> (Δ, -Δ) + return minus(a.data, b.data), Δ -> (Δ, -Δ) end ``` From 61c14afee42b513387503eb900e2ebb81fb15d77 Mon Sep 17 00:00:00 2001 From: harryscholes Date: Tue, 9 Oct 2018 13:05:38 +0100 Subject: [PATCH 2/2] Add usage example of custom gradients --- docs/src/internals/tracker.md | 19 ++++++++++++++++--- 1 file changed, 16 insertions(+), 3 deletions(-) diff --git a/docs/src/internals/tracker.md b/docs/src/internals/tracker.md index 895e4b52..456a9129 100644 --- a/docs/src/internals/tracker.md +++ b/docs/src/internals/tracker.md @@ -100,16 +100,16 @@ minus(a, b) = a - b Firstly, we must tell the tracker system to stop when it sees a call to `minus`, and record it. We can do this using dispatch: ```julia -using Flux.Tracker: TrackedReal, track, @grad +using Flux.Tracker: TrackedArray, track, @grad -minus(a::TrackedReal, b::TrackedReal) = Tracker.track(minus, a, b) +minus(a::TrackedArray, b::TrackedArray) = track(minus, a, b) ``` `track` takes care of building a new `Tracked` object and recording the operation on the tape. We just need to provide a gradient definition. ```julia @grad function minus(a, b) - return minus(a.data, b.data), Δ -> (Δ, -Δ) + return minus(data(a), data(b)), Δ -> (Δ, -Δ) end ``` @@ -121,6 +121,19 @@ Note that in the backpropagator we don't call `data(a)`; we *do* in fact want to @grad a * b = data(a)*data(b), Δ -> (Δ*b, a*Δ) ``` +We can then calculate the first derivative of `minus` as follows: + +```julia +a = param([1,2,3]) +b = param([3,2,1]) + +c = minus(a, b) # [-2.0 (tracked), 0.0 (tracked), 2.0 (tracked)] + +Tracker.back!(c, 1) +Tracker.grad(a) # [1.00, 1.00, 1.00] +Tracker.grad(b) # [-1.00, -1.00, -1.00] +``` + For multi-argument functions with custom gradients, you likely want to catch not just `minus(::TrackedArray, ::TrackedArray)` but also `minus(::Array, TrackedArray)` and so on. To do so, just define those extra signatures as needed: ```julia