Correct Custom Gradients docs
* Fixed a type signature that was incorrect. * Also, replaced `data(a)` with `a.data`. Don't know if the syntax has changed (recently). This may also need to be corrected in line 121. MWE: ```julia using Flux using Flux.Tracker using Flux.Tracker: forward, TrackedReal, track, @grad minus(a, b) = a - b minus(a::TrackedReal, b::TrackedReal) = Tracker.track(minus, a, b) @grad function minus(a, b) return minus(a.data, b.data), Δ -> (Δ, -Δ) end a, b = param(2), param(4) c = minus(a, b) # -2.0 (tracked) Tracker.back!(c) Tracker.grad(a) # 1.00 Tracker.grad(b) # -1.00 ```
This commit is contained in:
parent
02ecca4c61
commit
179a1e8407
|
@ -102,14 +102,14 @@ Firstly, we must tell the tracker system to stop when it sees a call to `minus`,
|
|||
```julia
|
||||
using Flux.Tracker: TrackedReal, track, @grad
|
||||
|
||||
minus(a::TrackedArray, b::TrackedArray) = Tracker.track(minus, a, b)
|
||||
minus(a::TrackedReal, b::TrackedReal) = Tracker.track(minus, a, b)
|
||||
```
|
||||
|
||||
`track` takes care of building a new `Tracked` object and recording the operation on the tape. We just need to provide a gradient definition.
|
||||
|
||||
```julia
|
||||
@grad function minus(a, b)
|
||||
return minus(data(a),data(b)), Δ -> (Δ, -Δ)
|
||||
return minus(a.data, b.data), Δ -> (Δ, -Δ)
|
||||
end
|
||||
```
|
||||
|
||||
|
|
Loading…
Reference in New Issue