params update

Mike J Innes 2019-01-28 14:14:41 +00:00
parent bf0b5c5cef
commit e1cac76a34


@@ -28,10 +28,10 @@ When a function has many parameters, we can pass them all in explicitly:
 f(W, b, x) = W * x + b
 
 Tracker.gradient(f, 2, 3, 4)
-(4.0 (tracked), 1.0 (tracked), 2.0 (tracked))
+# (4.0 (tracked), 1.0 (tracked), 2.0 (tracked))
 ```
 
-But machine learning models can have *hundreds* of parameters! Flux offers a nice way to handle this. We can tell Flux to treat something as a parameter via `param`. Then we can collect these together and tell `gradient` to collect the gradients of all of them at once.
+But machine learning models can have *hundreds* of parameters! Flux offers a nice way to handle this. We can tell Flux to treat something as a parameter via `param`. Then we can collect these together and tell `gradient` to collect the gradients of all `params` at once.
 
 ```julia
 W = param(2) # 2.0 (tracked)
@@ -39,14 +39,13 @@ b = param(3) # 3.0 (tracked)
 f(x) = W * x + b
 
-params = Params([W, b])
-grads = Tracker.gradient(() -> f(4), params)
+grads = Tracker.gradient(() -> f(4), params(W, b))
 
 grads[W] # 4.0
 grads[b] # 1.0
 ```
 
-There are a few things to notice here. Firstly, `W` and `b` now show up as *tracked*. Tracked things behave like normal numbers or arrays, but keep records of everything you do with them, allowing Flux to calculate their gradients. `gradient` takes a zero-argument function; no arguments are necessary because the `Params` tell it what to differentiate.
+There are a few things to notice here. Firstly, `W` and `b` now show up as *tracked*. Tracked things behave like normal numbers or arrays, but keep records of everything you do with them, allowing Flux to calculate their gradients. `gradient` takes a zero-argument function; no arguments are necessary because the `params` tell it what to differentiate.
 
 This will come in really handy when dealing with big, complicated models. For now, though, let's start with something simple.
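
For reference, the updated example from this hunk can be assembled into a single runnable snippet. This is a minimal sketch against the Tracker-era Flux API; it uses only the names already defined in the hunks above (`param`, `params`, `Tracker.gradient`):

```julia
using Flux, Flux.Tracker

W = param(2)  # 2.0 (tracked)
b = param(3)  # 3.0 (tracked)

f(x) = W * x + b

# params(W, b) builds the parameter collection that Params([W, b]) built before,
# so gradient only needs a zero-argument closure to know what to differentiate.
grads = Tracker.gradient(() -> f(4), params(W, b))

grads[W]  # 4.0
grads[b]  # 1.0
```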
@@ -77,7 +76,7 @@ using Flux.Tracker
 W = param(W)
 b = param(b)
 
-gs = Tracker.gradient(() -> loss(x, y), Params([W, b]))
+gs = Tracker.gradient(() -> loss(x, y), params(W, b))
 ```
 
 Now that we have gradients, we can pull them out and update `W` to train the model. The `update!(W, Δ)` function applies `W = W + Δ`, which we can use for gradient descent.
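
As a rough illustration of that last paragraph: one gradient-descent step using the gradients `gs` computed above. This sketch assumes the `W`, `loss`, `x`, and `y` defined earlier in the documentation page, and that `update!` is importable from `Flux.Tracker` as in the docs of that era:

```julia
using Flux.Tracker: update!  # assumed import path for the Tracker-era update!

η = 0.1      # learning rate (illustrative value)
Δ = gs[W]    # gradient of the loss with respect to W

# update!(W, -η*Δ) applies W = W + (-η*Δ), i.e. one gradient-descent step on W
update!(W, -η * Δ)

loss(x, y)   # the loss should now be a little lower
```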