extend update! with an optimiser

commit 0f8a4a48c6 (parent 0f2975d905)
@@ -3,7 +3,7 @@
 Consider a [simple linear regression](../models/basics.md). We create some dummy data, calculate a loss, and backpropagate to calculate gradients for the parameters `W` and `b`.
 
 ```julia
-using Flux.Tracker
+using Flux, Flux.Tracker
 
 W = param(rand(2, 5))
 b = param(rand(2))
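For context, the surrounding docs snippet builds up the model that the hunks below keep editing. A runnable sketch of that state, where `predict` is inferred from the `loss` definition visible in the next hunk header (so it is an assumption, albeit one that matches the linked basics page):

```julia
using Flux, Flux.Tracker

W = param(rand(2, 5))   # param wraps an array so Tracker records gradients for it
b = param(rand(2))

predict(x) = W*x .+ b                    # inferred from ../models/basics.md
loss(x, y) = sum((predict(x) .- y).^2)   # shown in the @@ -14,8 hunk header below
```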
@@ -14,8 +14,8 @@ loss(x, y) = sum((predict(x) .- y).^2)
 x, y = rand(5), rand(2) # Dummy data
 l = loss(x, y) # ~ 3
 
-params = Params([W, b])
-grads = Tracker.gradient(() -> loss(x, y), params)
+θ = Params([W, b])
+grads = Tracker.gradient(() -> loss(x, y), θ)
 ```
 
 We want to update each parameter, using the gradient, in order to improve (reduce) the loss. Here's one way to do that:
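The "one way to do that" code falls between these two hunks, so the diff doesn't show it. Judging from the old line visible in the next hunk, it is the manual step sketched below; `η` is the learning-rate name the old call uses, and `Tracker.update!(p, Δ)` applies `Δ` to a tracked parameter in place:

```julia
using Flux.Tracker: update!

η = 0.1  # learning rate (the name η comes from the old call site below)
for p in (W, b)
  update!(p, -η * grads[p])  # step each parameter against its gradient, in place
end
```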
@@ -35,7 +35,7 @@ Running this will alter the parameters `W` and `b` and our loss should go down.
 opt = Descent(0.1) # Gradient descent with learning rate 0.1
 
 for p in (W, b)
-  update!(opt, p, -η * grads[p])
+  update!(opt, p, grads[p])
 end
 ```
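The point of the new call form is that the step size now lives in the optimiser rather than at the call site, so swapping the update rule requires no change to the loop. A sketch, assuming `ADAM` (Flux's Adam optimiser of this era) as the drop-in replacement:

```julia
opt = ADAM(0.001)            # swap in for Descent(0.1); nothing else changes

for p in (W, b)
  update!(opt, p, grads[p])  # opt decides how the raw gradient becomes a step
end
```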
@@ -1,7 +1,11 @@
 using Juno
-using Flux.Tracker: data, grad, back!
+import Flux.Tracker: data, grad, back!, update!
 import Base.depwarn
 
+function update!(opt, x, x̄)
+  update!(x, apply!(opt, x, copy(data(x̄))))
+end
+
 function _update_params!(opt, xs)
   for x in xs
     Δ = apply!(opt, x.data, x.grad)