params update

Mike J Innes 2019-01-28 14:14:41 +00:00
parent bf0b5c5cef
commit e1cac76a34
1 changed file with 5 additions and 6 deletions

@@ -28,10 +28,10 @@ When a function has many parameters, we can pass them all in explicitly:
 f(W, b, x) = W * x + b
 Tracker.gradient(f, 2, 3, 4)
-(4.0 (tracked), 1.0 (tracked), 2.0 (tracked))
+# (4.0 (tracked), 1.0 (tracked), 2.0 (tracked))
 ```
-But machine learning models can have *hundreds* of parameters! Flux offers a nice way to handle this. We can tell Flux to treat something as a parameter via `param`. Then we can collect these together and tell `gradient` to collect the gradients of all of them at once.
+But machine learning models can have *hundreds* of parameters! Flux offers a nice way to handle this. We can tell Flux to treat something as a parameter via `param`. Then we can collect these together and tell `gradient` to collect the gradients of all `params` at once.
 ```julia
 W = param(2) # 2.0 (tracked)
@@ -39,14 +39,13 @@ b = param(3) # 3.0 (tracked)
 f(x) = W * x + b
-params = Params([W, b])
-grads = Tracker.gradient(() -> f(4), params)
+grads = Tracker.gradient(() -> f(4), params(W, b))
 grads[W] # 4.0
 grads[b] # 1.0
 ```
-There are a few things to notice here. Firstly, `W` and `b` now show up as *tracked*. Tracked things behave like normal numbers or arrays, but keep records of everything you do with them, allowing Flux to calculate their gradients. `gradient` takes a zero-argument function; no arguments are necessary because the `Params` tell it what to differentiate.
+There are a few things to notice here. Firstly, `W` and `b` now show up as *tracked*. Tracked things behave like normal numbers or arrays, but keep records of everything you do with them, allowing Flux to calculate their gradients. `gradient` takes a zero-argument function; no arguments are necessary because the `params` tell it what to differentiate.
 This will come in really handy when dealing with big, complicated models. For now, though, let's start with something simple.
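The point about big models is that `params` lets you collect every parameter of a layer or network at once, rather than listing them by hand. As a rough illustration (not part of this commit), here is a minimal sketch of the same implicit-parameter pattern applied to a whole layer, assuming a Tracker-era Flux where `Dense`, `params`, and `Tracker.gradient` behave as in the snippets above:

```julia
using Flux, Flux.Tracker

# A small layer with many parameters: a 5×10 weight matrix and a bias vector.
m = Dense(10, 5)

loss(x) = sum(m(x))
x = rand(10)

# `params(m)` collects every tracked parameter of the model,
# so the zero-argument closure needs no explicit arguments.
gs = Tracker.gradient(() -> loss(x), params(m))

gs[m.W]  # gradient of the loss w.r.t. the weight matrix
gs[m.b]  # gradient of the loss w.r.t. the bias vector
```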
@@ -77,7 +76,7 @@ using Flux.Tracker
 W = param(W)
 b = param(b)
-gs = Tracker.gradient(() -> loss(x, y), Params([W, b]))
+gs = Tracker.gradient(() -> loss(x, y), params(W, b))
 ```
 Now that we have gradients, we can pull them out and update `W` to train the model. The `update!(W, Δ)` function applies `W = W + Δ`, which we can use for gradient descent.
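For context, the update step this paragraph describes would look roughly as follows; a sketch assuming Tracker's `update!` and the `gs` computed above, with an illustrative learning rate of `0.1`:

```julia
using Flux.Tracker: update!

# Pull the gradient of the loss with respect to W out of the Grads object.
Δ = gs[W]

# Gradient descent: step against the gradient, scaled by the learning rate.
update!(W, -0.1Δ)
```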