avoid implementation details in docs
parent d76e790818
commit 5d8b63dc65
@@ -31,12 +31,12 @@ back!(l)
 `loss(x, y)` returns the same number, but it's now a *tracked* value that records gradients as it goes along. Calling `back!` then accumulates the gradient of `W` and `b`. We can see what this gradient is, and modify `W` to train the model.
 
 ```julia
-W.grad
-
-# Update the parameter
-W.data .-= 0.1(W.grad)
-
-# Reset the gradient
-W.grad .= 0
+using Flux.Tracker: grad, update!
+
+Δ = grad(W)
+
+# Update the parameter and reset the gradient
+update!(W, -0.1Δ)
 
 loss(x, y) # ~ 2.5
 ```
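For context, the snippet above relies on definitions from earlier on the same docs page. A minimal sketch of that setup, reconstructed here for readability (the shapes and random data are illustrative, not part of this commit):

```julia
using Flux
using Flux.Tracker: back!

W = param(rand(2, 5))   # tracked weight matrix
b = param(rand(2))      # tracked bias

predict(x) = W*x .+ b
loss(x, y) = sum((predict(x) .- y).^2)

x, y = rand(5), rand(2) # dummy data, illustrative only
l = loss(x, y)          # a tracked scalar

back!(l)                # accumulate gradients into W and b
```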
@@ -17,16 +17,17 @@ back!(l)
 We want to update each parameter, using the gradient, in order to improve (reduce) the loss. Here's one way to do that:
 
 ```julia
-function update()
+using Flux.Tracker: grad, update!
+
+function sgd()
   η = 0.1 # Learning Rate
   for p in (W, b)
-    p.data .-= η .* p.grad # Apply the update
-    p.grad .= 0            # Clear the gradient
+    update!(p, -η * grad(p))
   end
 end
 ```
 
-If we call `update`, the parameters `W` and `b` will change and our loss should go down.
+If we call `sgd`, the parameters `W` and `b` will change and our loss should go down.
 
 There are two pieces here: one is that we need a list of trainable parameters for the model (`[W, b]` in this case), and the other is the update step. In this case the update is simply gradient descent (`x .-= η .* Δ`), but we might choose to do something more advanced, like adding momentum.
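Assuming the setup sketched earlier, here is a quick hypothetical check that `sgd` behaves as claimed. Note that `back!` must repopulate the gradients before each step, since `update!` clears them:

```julia
using Flux.Tracker: back!

for i in 1:10
  back!(loss(x, y))  # recompute and accumulate gradients
  sgd()              # one descent step; also resets the gradients
end

loss(x, y)  # should now be lower than before
```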
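On the momentum remark: a hypothetical sketch of a heavier-duty update built on the same `grad`/`update!` API. The velocity store `V`, the coefficient `ρ`, and the function name are illustrative, not part of this commit:

```julia
using Flux.Tracker: grad, update!

const V = IdDict()  # per-parameter velocity, illustrative only

function sgd_momentum(ps; η = 0.1, ρ = 0.9)
  for p in ps
    v = get!(V, p, zero(grad(p)))  # velocity starts at zero
    v .= ρ .* v .+ η .* grad(p)    # decay old velocity, add the new gradient step
    update!(p, -v)                 # apply the step and clear the gradient
  end
end

sgd_momentum((W, b))
```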
@@ -47,6 +47,12 @@ isleaf(x::Tracked) = x.f == Call(nothing)
 data(x::Tracked) = x.data
 grad(x::Tracked) = x.grad
 
+function update!(x, Δ)
+  tracker(x).data += Δ
+  tracker(x).grad .= 0
+  return x
+end
+
 include("back.jl")
 include("scalar.jl")
 include("array.jl")
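A hypothetical REPL-style illustration of what the new `update!` does to a tracked parameter, using the existing `param`, `back!`, and `grad` (the shapes are illustrative):

```julia
using Flux
using Flux.Tracker: back!, grad, update!

W = param(rand(2, 5))
back!(sum(W * rand(5)))     # populate W's gradient

update!(W, -0.1 * grad(W))  # adds Δ to the stored data...
all(grad(W) .== 0)          # ...and resets the gradient to zero
```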