diff --git a/docs/src/internals/tracker.md b/docs/src/internals/tracker.md index 3d39451d..456a9129 100644 --- a/docs/src/internals/tracker.md +++ b/docs/src/internals/tracker.md @@ -100,16 +100,16 @@ minus(a, b) = a - b Firstly, we must tell the tracker system to stop when it sees a call to `minus`, and record it. We can do this using dispatch: ```julia -using Flux.Tracker: TrackedReal, track, @grad +using Flux.Tracker: TrackedArray, track, @grad -minus(a::TrackedArray, b::TrackedArray) = Tracker.track(minus, a, b) +minus(a::TrackedArray, b::TrackedArray) = track(minus, a, b) ``` `track` takes care of building a new `Tracked` object and recording the operation on the tape. We just need to provide a gradient definition. ```julia @grad function minus(a, b) - return minus(data(a),data(b)), Δ -> (Δ, -Δ) + return minus(data(a), data(b)), Δ -> (Δ, -Δ) end ``` @@ -121,6 +121,19 @@ Note that in the backpropagator we don't call `data(a)`; we *do* in fact want to @grad a * b = data(a)*data(b), Δ -> (Δ*b, a*Δ) ``` +We can then calculate the first derivative of `minus` as follows: + +```julia +a = param([1,2,3]) +b = param([3,2,1]) + +c = minus(a, b) # [-2.0 (tracked), 0.0 (tracked), 2.0 (tracked)] + +Tracker.back!(c, 1) +Tracker.grad(a) # [1.00, 1.00, 1.00] +Tracker.grad(b) # [-1.00, -1.00, -1.00] +``` + For multi-argument functions with custom gradients, you likely want to catch not just `minus(::TrackedArray, ::TrackedArray)` but also `minus(::Array, TrackedArray)` and so on. To do so, just define those extra signatures as needed: ```julia