avoid implementation details in docs

Mike Innes 2018-06-29 13:53:50 +01:00
parent d76e790818
commit 5d8b63dc65
3 changed files with 16 additions and 9 deletions


@ -31,12 +31,12 @@ back!(l)
`loss(x, y)` returns the same number, but it's now a *tracked* value that records gradients as it goes along. Calling `back!` then accumulates the gradient of `W` and `b`. We can see what this gradient is, and modify `W` to train the model.
```julia
-W.grad
+using Flux.Tracker: grad, update!
-# Update the parameter
-W.data .-= 0.1(W.grad)
-# Reset the gradient
-W.grad .= 0
+Δ = grad(W)
+# Update the parameter and reset the gradient
+update!(W, -0.1Δ)
 loss(x, y) # ~ 2.5
```
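
To see the new workflow end to end, here is a minimal, self-contained sketch assuming the `Flux.Tracker` interface shown in this diff; the toy data, `predict`, and array shapes are purely illustrative:

```julia
using Flux
using Flux.Tracker: back!, grad, update!

# Illustrative parameters and data (shapes chosen arbitrarily).
W = param(rand(2, 5))
b = param(rand(2))
x, y = rand(5), rand(2)

predict(x) = W*x .+ b
loss(x, y) = sum((predict(x) .- y).^2)

l = loss(x, y)      # a tracked value that records gradients
back!(l)            # accumulate gradients into W and b

Δ = grad(W)         # read W's accumulated gradient
update!(W, -0.1Δ)   # apply the step and reset the gradient
loss(x, y)          # should now be lower
```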


@ -17,16 +17,17 @@ back!(l)
We want to update each parameter, using the gradient, in order to improve (reduce) the loss. Here's one way to do that:
```julia
-function update()
+using Flux.Tracker: grad, update!
+
+function sgd()
   η = 0.1 # Learning Rate
   for p in (W, b)
-    p.data .-= η .* p.grad # Apply the update
-    p.grad .= 0            # Clear the gradient
+    update!(p, -η * grad(p))
   end
 end
```
-If we call `update`, the parameters `W` and `b` will change and our loss should go down.
+If we call `sgd`, the parameters `W` and `b` will change and our loss should go down.
There are two pieces here: one is that we need a list of trainable parameters for the model (`[W, b]` in this case), and the other is the update step. In this case the update is simply gradient descent (`x .-= η .* Δ`), but we might choose to do something more advanced, like adding momentum.
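
As a sketch of that "more advanced" direction, momentum can be layered on the same `grad`/`update!` calls; the `vs` buffers, the `sgd_momentum` name, and the hyperparameter values below are assumptions for illustration, not part of Flux:

```julia
using Flux.Tracker: grad, update!

# One velocity buffer per parameter, initialised to zero.
vs = [zero(grad(p)) for p in (W, b)]

function sgd_momentum(ps, vs; η = 0.1, ρ = 0.9)
  for (p, v) in zip(ps, vs)
    v .= ρ .* v .+ η .* grad(p)  # fold the new gradient into the running velocity
    update!(p, -v)               # take the step and clear the gradient
  end
end

sgd_momentum((W, b), vs)
```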


@ -47,6 +47,12 @@ isleaf(x::Tracked) = x.f == Call(nothing)
 data(x::Tracked) = x.data
 grad(x::Tracked) = x.grad
+
+function update!(x, Δ)
+  tracker(x).data += Δ
+  tracker(x).grad .= 0
+  return x
+end

 include("back.jl")
 include("scalar.jl")
 include("array.jl")