params update
parent bf0b5c5cef
commit e1cac76a34
@@ -28,10 +28,10 @@ When a function has many parameters, we can pass them all in explicitly:
 f(W, b, x) = W * x + b
 
 Tracker.gradient(f, 2, 3, 4)
-(4.0 (tracked), 1.0 (tracked), 2.0 (tracked))
+# (4.0 (tracked), 1.0 (tracked), 2.0 (tracked))
 ```
 
-But machine learning models can have *hundreds* of parameters! Flux offers a nice way to handle this. We can tell Flux to treat something as a parameter via `param`. Then we can collect these together and tell `gradient` to collect the gradients of all of them at once.
+But machine learning models can have *hundreds* of parameters! Flux offers a nice way to handle this. We can tell Flux to treat something as a parameter via `param`. Then we can collect these together and tell `gradient` to collect the gradients of all `params` at once.
 
 ```julia
 W = param(2) # 2.0 (tracked)
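For reference, the explicit-parameter form shown in this hunk assembles into a self-contained snippet like the following; this is a minimal sketch assuming the Tracker-era Flux API this page documents:

```julia
using Flux.Tracker

f(W, b, x) = W * x + b

# Differentiate f with respect to each explicit argument at the point (2, 3, 4).
Tracker.gradient(f, 2, 3, 4)
# (4.0 (tracked), 1.0 (tracked), 2.0 (tracked))
```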
@@ -39,14 +39,13 @@ b = param(3) # 3.0 (tracked)
 
 f(x) = W * x + b
 
-params = Params([W, b])
-grads = Tracker.gradient(() -> f(4), params)
+grads = Tracker.gradient(() -> f(4), params(W, b))
 
 grads[W] # 4.0
 grads[b] # 1.0
 ```
 
-There are a few things to notice here. Firstly, `W` and `b` now show up as *tracked*. Tracked things behave like normal numbers or arrays, but keep records of everything you do with them, allowing Flux to calculate their gradients. `gradient` takes a zero-argument function; no arguments are necessary because the `Params` tell it what to differentiate.
+There are a few things to notice here. Firstly, `W` and `b` now show up as *tracked*. Tracked things behave like normal numbers or arrays, but keep records of everything you do with them, allowing Flux to calculate their gradients. `gradient` takes a zero-argument function; no arguments are necessary because the `params` tell it what to differentiate.
 
 This will come in really handy when dealing with big, complicated models. For now, though, let's start with something simple.
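The new `params(W, b)` call from this hunk, assembled into one runnable snippet; again a sketch under the same Tracker-era assumptions:

```julia
using Flux, Flux.Tracker

W = param(2) # 2.0 (tracked)
b = param(3) # 3.0 (tracked)

f(x) = W * x + b

# Zero-argument closure: params(W, b) tells gradient what to differentiate.
grads = Tracker.gradient(() -> f(4), params(W, b))

grads[W] # 4.0
grads[b] # 1.0
```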
@@ -77,7 +76,7 @@ using Flux.Tracker
|
|||
W = param(W)
|
||||
b = param(b)
|
||||
|
||||
gs = Tracker.gradient(() -> loss(x, y), Params([W, b]))
|
||||
gs = Tracker.gradient(() -> loss(x, y), params(W, b))
|
||||
```
|
||||
|
||||
Now that we have gradients, we can pull them out and update `W` to train the model. The `update!(W, Δ)` function applies `W = W + Δ`, which we can use for gradient descent.
|
||||
|
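The update step that paragraph describes could look like the following sketch; `Tracker.update!` and the 0.1 learning rate are illustrative assumptions, with `gs` taken from the hunk above:

```julia
using Flux.Tracker: update!

# One gradient-descent step: move W against its gradient, scaled by a 0.1 learning rate.
Δ = gs[W]
update!(W, -0.1Δ)
```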