extend update! with an optimiser

This commit is contained in:
Mike J Innes 2019-01-28 14:10:09 +00:00
parent 0f2975d905
commit 0f8a4a48c6
2 changed files with 9 additions and 5 deletions

View File

@ -3,7 +3,7 @@
Consider a [simple linear regression](../models/basics.md). We create some dummy data, calculate a loss, and backpropagate to calculate gradients for the parameters `W` and `b`. Consider a [simple linear regression](../models/basics.md). We create some dummy data, calculate a loss, and backpropagate to calculate gradients for the parameters `W` and `b`.
```julia ```julia
using Flux.Tracker using Flux, Flux.Tracker
W = param(rand(2, 5)) W = param(rand(2, 5))
b = param(rand(2)) b = param(rand(2))
@ -14,8 +14,8 @@ loss(x, y) = sum((predict(x) .- y).^2)
x, y = rand(5), rand(2) # Dummy data x, y = rand(5), rand(2) # Dummy data
l = loss(x, y) # ~ 3 l = loss(x, y) # ~ 3
params = Params([W, b]) θ = Params([W, b])
grads = Tracker.gradient(() -> loss(x, y), params) grads = Tracker.gradient(() -> loss(x, y), θ)
``` ```
We want to update each parameter, using the gradient, in order to improve (reduce) the loss. Here's one way to do that: We want to update each parameter, using the gradient, in order to improve (reduce) the loss. Here's one way to do that:
@ -35,7 +35,7 @@ Running this will alter the parameters `W` and `b` and our loss should go down.
opt = Descent(0.1) # Gradient descent with learning rate 0.1 opt = Descent(0.1) # Gradient descent with learning rate 0.1
for p in (W, b) for p in (W, b)
update!(opt, p, -η * grads[p]) update!(opt, p, grads[p])
end end
``` ```

View File

@ -1,7 +1,11 @@
using Juno using Juno
using Flux.Tracker: data, grad, back! import Flux.Tracker: data, grad, back!, update!
import Base.depwarn import Base.depwarn
function update!(opt, x, )
update!(x, apply!(opt, x, copy(data())))
end
function _update_params!(opt, xs) function _update_params!(opt, xs)
for x in xs for x in xs
Δ = apply!(opt, x.data, x.grad) Δ = apply!(opt, x.data, x.grad)