Add usage example of custom gradients

This commit is contained in:
harryscholes 2018-10-09 13:05:38 +01:00
parent 179a1e8407
commit 61c14afee4
1 changed files with 16 additions and 3 deletions

View File

@ -100,16 +100,16 @@ minus(a, b) = a - b
Firstly, we must tell the tracker system to stop when it sees a call to `minus`, and record it. We can do this using dispatch:
```julia
using Flux.Tracker: TrackedReal, track, @grad
using Flux.Tracker: TrackedArray, track, @grad
minus(a::TrackedReal, b::TrackedReal) = Tracker.track(minus, a, b)
minus(a::TrackedArray, b::TrackedArray) = track(minus, a, b)
```
`track` takes care of building a new `Tracked` object and recording the operation on the tape. We just need to provide a gradient definition.
```julia
@grad function minus(a, b)
return minus(a.data, b.data), Δ -> (Δ, -Δ)
return minus(data(a), data(b)), Δ -> (Δ, -Δ)
end
```
@ -121,6 +121,19 @@ Note that in the backpropagator we don't call `data(a)`; we *do* in fact want to
@grad a * b = data(a)*data(b), Δ -> (Δ*b, a*Δ)
```
We can then calculate the first derivative of `minus` as follows:
```julia
a = param([1,2,3])
b = param([3,2,1])
c = minus(a, b) # [-2.0 (tracked), 0.0 (tracked), 2.0 (tracked)]
Tracker.back!(c, 1)
Tracker.grad(a) # [1.00, 1.00, 1.00]
Tracker.grad(b) # [-1.00, -1.00, -1.00]
```
For multi-argument functions with custom gradients, you likely want to catch not just `minus(::TrackedArray, ::TrackedArray)` but also `minus(::Array, TrackedArray)` and so on. To do so, just define those extra signatures as needed:
```julia