Add usage example of custom gradients
This commit is contained in:
parent
179a1e8407
commit
61c14afee4
|
@ -100,16 +100,16 @@ minus(a, b) = a - b
|
|||
Firstly, we must tell the tracker system to stop when it sees a call to `minus`, and record it. We can do this using dispatch:
|
||||
|
||||
```julia
|
||||
using Flux.Tracker: TrackedReal, track, @grad
|
||||
using Flux.Tracker: TrackedArray, track, @grad
|
||||
|
||||
minus(a::TrackedReal, b::TrackedReal) = Tracker.track(minus, a, b)
|
||||
minus(a::TrackedArray, b::TrackedArray) = track(minus, a, b)
|
||||
```
|
||||
|
||||
`track` takes care of building a new `Tracked` object and recording the operation on the tape. We just need to provide a gradient definition.
|
||||
|
||||
```julia
|
||||
@grad function minus(a, b)
|
||||
return minus(a.data, b.data), Δ -> (Δ, -Δ)
|
||||
return minus(data(a), data(b)), Δ -> (Δ, -Δ)
|
||||
end
|
||||
```
|
||||
|
||||
|
@ -121,6 +121,19 @@ Note that in the backpropagator we don't call `data(a)`; we *do* in fact want to
|
|||
@grad a * b = data(a)*data(b), Δ -> (Δ*b, a*Δ)
|
||||
```
|
||||
|
||||
We can then calculate the first derivative of `minus` as follows:
|
||||
|
||||
```julia
|
||||
a = param([1,2,3])
|
||||
b = param([3,2,1])
|
||||
|
||||
c = minus(a, b) # [-2.0 (tracked), 0.0 (tracked), 2.0 (tracked)]
|
||||
|
||||
Tracker.back!(c, 1)
|
||||
Tracker.grad(a) # [1.00, 1.00, 1.00]
|
||||
Tracker.grad(b) # [-1.00, -1.00, -1.00]
|
||||
```
|
||||
|
||||
For multi-argument functions with custom gradients, you likely want to catch not just `minus(::TrackedArray, ::TrackedArray)` but also `minus(::Array, TrackedArray)` and so on. To do so, just define those extra signatures as needed:
|
||||
|
||||
```julia
|
||||
|
|
Loading…
Reference in New Issue