params update
parent bf0b5c5cef
commit e1cac76a34
@@ -28,10 +28,10 @@ When a function has many parameters, we can pass them all in explicitly:
 f(W, b, x) = W * x + b
 
 Tracker.gradient(f, 2, 3, 4)
-(4.0 (tracked), 1.0 (tracked), 2.0 (tracked))
+# (4.0 (tracked), 1.0 (tracked), 2.0 (tracked))
 ```
 
-But machine learning models can have *hundreds* of parameters! Flux offers a nice way to handle this. We can tell Flux to treat something as a parameter via `param`. Then we can collect these together and tell `gradient` to collect the gradients of all of them at once.
+But machine learning models can have *hundreds* of parameters! Flux offers a nice way to handle this. We can tell Flux to treat something as a parameter via `param`. Then we can collect these together and tell `gradient` to collect the gradients of all `params` at once.
 
 ```julia
 W = param(2) # 2.0 (tracked)
@@ -39,14 +39,13 @@ b = param(3) # 3.0 (tracked)
 
 f(x) = W * x + b
 
-params = Params([W, b])
-grads = Tracker.gradient(() -> f(4), params)
+grads = Tracker.gradient(() -> f(4), params(W, b))
 
 grads[W] # 4.0
 grads[b] # 1.0
 ```
 
-There are a few things to notice here. Firstly, `W` and `b` now show up as *tracked*. Tracked things behave like normal numbers or arrays, but keep records of everything you do with them, allowing Flux to calculate their gradients. `gradient` takes a zero-argument function; no arguments are necessary because the `Params` tell it what to differentiate.
+There are a few things to notice here. Firstly, `W` and `b` now show up as *tracked*. Tracked things behave like normal numbers or arrays, but keep records of everything you do with them, allowing Flux to calculate their gradients. `gradient` takes a zero-argument function; no arguments are necessary because the `params` tell it what to differentiate.
 
 This will come in really handy when dealing with big, complicated models. For now, though, let's start with something simple.
 
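(Not part of the diff: a minimal REPL sketch of the behaviour the updated paragraph describes, assuming a Tracker-based Flux where `param`, `params`, and `Tracker.data` are available as in these docs; `Tracker.data` does not appear in the page above and is shown here only to illustrate unwrapping a tracked value.)

```julia
using Flux, Flux.Tracker

W = param(2)   # 2.0 (tracked)
b = param(3)   # 3.0 (tracked)

f(x) = W * x + b
f(4)           # 11.0 (tracked): results of tracked operations are tracked too

# params(W, b) collects the tracked parameters; gradient then takes a
# zero-argument closure, since the collected params say what to differentiate.
grads = Tracker.gradient(() -> f(4), params(W, b))

grads[W]       # 4.0
grads[b]       # 1.0

Tracker.data(W)  # 2.0, the plain value without the tracking wrapper
```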
@@ -77,7 +76,7 @@ using Flux.Tracker
 W = param(W)
 b = param(b)
 
-gs = Tracker.gradient(() -> loss(x, y), Params([W, b]))
+gs = Tracker.gradient(() -> loss(x, y), params(W, b))
 ```
 
 Now that we have gradients, we can pull them out and update `W` to train the model. The `update!(W, Δ)` function applies `W = W + Δ`, which we can use for gradient descent.
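(Again not part of the diff: a sketch of the gradient-descent step that the `update!` paragraph describes, assuming `update!` can be imported from `Flux.Tracker` and that `loss`, `x`, `y`, `W`, and `b` are defined as in the surrounding docs; the learning rate `η = 0.1` is an arbitrary illustrative choice.)

```julia
using Flux, Flux.Tracker
using Flux.Tracker: update!

η = 0.1  # learning rate, chosen arbitrarily for illustration

gs = Tracker.gradient(() -> loss(x, y), params(W, b))

# update!(W, Δ) applies W = W + Δ, so stepping by -η * gs[W]
# moves W against the gradient and should reduce the loss.
update!(W, -η * gs[W])
update!(b, -η * gs[b])
```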