Merge pull request #409 from harryscholes/patch-2

Correct Custom Gradients docs
This commit is contained in:
Mike J Innes 2018-10-09 14:09:09 +01:00 committed by GitHub
commit 3285afa45a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -100,16 +100,16 @@ minus(a, b) = a - b
Firstly, we must tell the tracker system to stop when it sees a call to `minus`, and record it. We can do this using dispatch: Firstly, we must tell the tracker system to stop when it sees a call to `minus`, and record it. We can do this using dispatch:
```julia ```julia
using Flux.Tracker: TrackedReal, track, @grad using Flux.Tracker: TrackedArray, track, @grad
minus(a::TrackedArray, b::TrackedArray) = Tracker.track(minus, a, b) minus(a::TrackedArray, b::TrackedArray) = track(minus, a, b)
``` ```
`track` takes care of building a new `Tracked` object and recording the operation on the tape. We just need to provide a gradient definition. `track` takes care of building a new `Tracked` object and recording the operation on the tape. We just need to provide a gradient definition.
```julia ```julia
@grad function minus(a, b) @grad function minus(a, b)
return minus(data(a),data(b)), Δ -> (Δ, -Δ) return minus(data(a), data(b)), Δ -> (Δ, -Δ)
end end
``` ```
@ -121,6 +121,19 @@ Note that in the backpropagator we don't call `data(a)`; we *do* in fact want to
@grad a * b = data(a)*data(b), Δ -> (Δ*b, a*Δ) @grad a * b = data(a)*data(b), Δ -> (Δ*b, a*Δ)
``` ```
We can then calculate the first derivative of `minus` as follows:
```julia
a = param([1,2,3])
b = param([3,2,1])
c = minus(a, b) # [-2.0 (tracked), 0.0 (tracked), 2.0 (tracked)]
Tracker.back!(c, 1)
Tracker.grad(a) # [1.00, 1.00, 1.00]
Tracker.grad(b) # [-1.00, -1.00, -1.00]
```
For multi-argument functions with custom gradients, you likely want to catch not just `minus(::TrackedArray, ::TrackedArray)` but also `minus(::Array, TrackedArray)` and so on. To do so, just define those extra signatures as needed: For multi-argument functions with custom gradients, you likely want to catch not just `minus(::TrackedArray, ::TrackedArray)` but also `minus(::Array, TrackedArray)` and so on. To do so, just define those extra signatures as needed:
```julia ```julia