Correct Custom Gradients docs

* Fixed a type signature that was incorrect.
* Also, replaced `data(a)` with `a.data`. Don't know if the syntax has changed (recently). This may also need to be corrected in line 121.

MWE:

```julia
using Flux
using Flux.Tracker
using Flux.Tracker: forward, TrackedReal, track, @grad

minus(a, b) = a - b
minus(a::TrackedReal, b::TrackedReal) = Tracker.track(minus, a, b)
@grad function minus(a, b)
    return minus(a.data, b.data), Δ -> (Δ, -Δ)
end

a, b = param(2), param(4)
c = minus(a, b)  # -2.0 (tracked)
Tracker.back!(c)

Tracker.grad(a)  # 1.00
Tracker.grad(b)  # -1.00
```
This commit is contained in:
Harry 2018-09-21 16:57:54 +01:00 committed by GitHub
parent 02ecca4c61
commit 179a1e8407
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 2 additions and 2 deletions

View File

@ -102,14 +102,14 @@ Firstly, we must tell the tracker system to stop when it sees a call to `minus`,
```julia
using Flux.Tracker: TrackedReal, track, @grad
minus(a::TrackedArray, b::TrackedArray) = Tracker.track(minus, a, b)
minus(a::TrackedReal, b::TrackedReal) = Tracker.track(minus, a, b)
```
`track` takes care of building a new `Tracked` object and recording the operation on the tape. We just need to provide a gradient definition.
```julia
@grad function minus(a, b)
return minus(data(a),data(b)), Δ -> (Δ, -Δ)
return minus(a.data, b.data), Δ -> (Δ, -Δ)
end
```