Merge pull request #575 from FluxML/mji/update

Clean up parameter update API
Mike J Innes 2019-01-28 15:26:57 +00:00 committed by GitHub
commit 8386a49bf9
5 changed files with 31 additions and 27 deletions

View File

@@ -3,7 +3,7 @@
Consider a [simple linear regression](../models/basics.md). We create some dummy data, calculate a loss, and backpropagate to calculate gradients for the parameters `W` and `b`.
```julia
using Flux.Tracker
using Flux, Flux.Tracker
W = param(rand(2, 5))
b = param(rand(2))
@@ -14,8 +14,8 @@ loss(x, y) = sum((predict(x) .- y).^2)
x, y = rand(5), rand(2) # Dummy data
l = loss(x, y) # ~ 3
params = Params([W, b])
grads = Tracker.gradient(() -> loss(x, y), params)
θ = Params([W, b])
grads = Tracker.gradient(() -> loss(x, y), θ)
```
We want to update each parameter, using the gradient, in order to improve (reduce) the loss. Here's one way to do that:
@@ -35,7 +35,7 @@ Running this will alter the parameters `W` and `b` and our loss should go down.
opt = Descent(0.1) # Gradient descent with learning rate 0.1
for p in (W, b)
update!(opt, p, -η * grads[p])
update!(opt, p, grads[p])
end
```
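With this change the learning rate lives inside the optimiser, so the manual loop no longer scales the gradient by hand. A minimal end-to-end sketch of the revised docs example (the explicit `update!` import and the loop count are assumptions, not part of this diff):

```julia
using Flux, Flux.Tracker
using Flux.Tracker: update!    # three-argument method added in this PR

W = param(rand(2, 5))
b = param(rand(2))

predict(x) = W*x .+ b
loss(x, y) = sum((predict(x) .- y).^2)

x, y = rand(5), rand(2)        # dummy data
θ = Params([W, b])

opt = Descent(0.1)             # gradient descent with learning rate 0.1
for step in 1:100
  grads = Tracker.gradient(() -> loss(x, y), θ)
  for p in (W, b)
    update!(opt, p, grads[p])  # no manual -η * Δ any more
  end
end
```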

View File

@@ -4,7 +4,7 @@ using Flux: Params
check_decay(opt, decay) = decay == 0 ? opt : Optimiser(opt, InvDecay(decay))
# legacy update rule
updaterule(opt, ps) = () -> update!(opt, ps)
updaterule(opt, ps) = () -> _update_params!(opt, ps)
function SGD(params::Union{AbstractArray, Params}, η = 0.1; decay = 0.)
depwarn("SGD(params) is deprecated; use Descent(η::Float64) instead", :SGD)
@@ -117,7 +117,7 @@ struct OldOptimiser
func
end
update!(opt::OldOptimiser, ps) = opt.func()
_update_params!(opt::OldOptimiser, ps) = opt.func()
# Train function
function train!(loss, data, opt; cb = () -> ())
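The deprecation shim keeps the old call style alive by turning it into a closure over the renamed helper. A hypothetical walkthrough, assuming `SGD` finishes by returning `updaterule(opt, ps)` (the return statement is outside this hunk) and reusing `W`, `b`, `loss`, `x`, `y` from the docs example:

```julia
using Flux.Tracker: back!

ps   = Params([W, b])
step = SGD(ps, 0.1)   # depwarns; roughly () -> _update_params!(Descent(0.1), ps)

back!(loss(x, y))     # fill in the gradients first
step()                # one in-place update of every parameter in ps
```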

View File

@@ -18,7 +18,7 @@ end
Descent() = Descent(0.1)
function update!(o::Descent, x, Δ)
function apply!(o::Descent, x, Δ)
Δ .*= o.eta
end
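The rename makes the division of labour explicit: `apply!` computes (and generally mutates, in place) the step for one parameter, and the caller decides how to use it. A small sketch of that calling convention, with made-up values:

```julia
using Flux
using Flux: Optimise

opt = Descent(0.1)
w = rand(3)
g = rand(3)                            # pretend gradient for w

Δ = Optimise.apply!(opt, w, copy(g))   # returns 0.1 .* g; note it mutates its last argument
w .-= Δ                                # the caller performs the actual update
```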
@@ -35,7 +35,7 @@ end
Momentum(η = 0.01, ρ = 0.9) = Momentum(η, ρ, IdDict())
function update!(o::Momentum, x, Δ)
function apply!(o::Momentum, x, Δ)
η, ρ = o.eta, o.rho
v = get!(o.velocity, x, zero(x))::typeof(x)
@. v = ρ * v - η * Δ
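Stateful rules like `Momentum` keep one buffer per parameter array in an `IdDict`, keyed by the array's identity and created lazily with `get!`. A self-contained illustration of that pattern (the `MiniMomentum` type here is hypothetical, not part of Flux):

```julia
struct MiniMomentum
  eta::Float64
  rho::Float64
  velocity::IdDict           # one velocity buffer per parameter array
end
MiniMomentum(η = 0.01, ρ = 0.9) = MiniMomentum(η, ρ, IdDict())

function apply!(o::MiniMomentum, x, Δ)
  v = get!(o.velocity, x, zero(x))   # lazily create state the first time x is seen
  @. v = o.rho * v - o.eta * Δ
  @. Δ = -v                          # hand back the step the caller will subtract
end
```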
@@ -55,7 +55,7 @@ end
Nesterov(η = 0.001, ρ = 0.9) = Nesterov(η, ρ, IdDict())
function update!(o::Nesterov, x, Δ)
function apply!(o::Nesterov, x, Δ)
η, ρ = o.eta, o.rho
v = get!(o.velocity, x, zero(x))::typeof(x)
d = @. ρ^2 * v - (1+ρ) * η * Δ
@@ -78,7 +78,7 @@ end
RMSProp(η = 0.001, ρ = 0.9) = RMSProp(η, ρ, IdDict())
function update!(o::RMSProp, x, Δ)
function apply!(o::RMSProp, x, Δ)
η, ρ = o.eta, o.rho
acc = get!(o.acc, x, zero(x))::typeof(x)
@. acc = ρ * acc + (1 - ρ) * Δ^2
@@ -98,7 +98,7 @@ end
ADAM(η = 0.001, β = (0.9, 0.999)) = ADAM(η, β, IdDict())
function update!(o::ADAM, x, Δ)
function apply!(o::ADAM, x, Δ)
η, β = o.eta, o.beta
mt, vt, βp = get!(o.state, x, (zero(x), zero(x), β))
@. mt = β[1] * mt + (1 - β[1]) * Δ
@@ -122,7 +122,7 @@ end
AdaMax(η = 0.001, β = (0.9, 0.999)) = AdaMax(η, β, IdDict())
function update!(o::AdaMax, x, Δ)
function apply!(o::AdaMax, x, Δ)
η, β = o.eta, o.beta
mt, ut, βp = get!(o.state, x, (zero(x), zero(x), β))
@. mt = β[1] * mt + (1 - β[1]) * Δ
@@ -145,7 +145,7 @@ end
ADAGrad(η = 0.1) = ADAGrad(η, IdDict())
function update!(o::ADAGrad, x, Δ)
function apply!(o::ADAGrad, x, Δ)
η = o.eta
acc = get!(o.acc, x, fill(ϵ, size(x)))::typeof(x)
@. acc += Δ^2
@@ -165,7 +165,7 @@ end
ADADelta(ρ = 0.9) = ADADelta(ρ, IdDict())
function update!(o::ADADelta, x, Δ)
function apply!(o::ADADelta, x, Δ)
ρ = o.rho
acc, Δacc = get!(o.state, x, (zero(x), zero(x)))
@. acc = ρ * acc + (1 - ρ) * Δ^2
@@ -188,7 +188,7 @@ end
AMSGrad(η = 0.001, β = (0.9, 0.999)) = AMSGrad(η, β, IdDict())
function update!(o::AMSGrad, x, Δ)
function apply!(o::AMSGrad, x, Δ)
η, β = o.eta, o.beta
mt, vt, v̂t = get!(o.state, x, (fill(ϵ, size(x)), fill(ϵ, size(x)), fill(ϵ, size(x))))
@. mt = β[1] * mt + (1 - β[1]) * Δ
@@ -211,7 +211,7 @@ end
NADAM(η = 0.001, β = (0.9, 0.999)) = NADAM(η, β, IdDict())
function update!(o::NADAM, x, Δ)
function apply!(o::NADAM, x, Δ)
η, β = o.eta, o.beta
β1p, β2p = o.beta
mt, vt = get!(o.state, x, (zero(x), zero(x)))
@@ -250,9 +250,9 @@ Optimiser(o...) = Optimiser(Any[o...])
Base.getindex(c::Optimiser, i::AbstractArray) = Optimiser(c.os[i]...)
function update!(o::Optimiser, x, Δ)
function apply!(o::Optimiser, x, Δ)
for opt in o.os
Δ = update!(opt, x, Δ)
Δ = apply!(opt, x, Δ)
end
return Δ
end
@@ -272,7 +272,7 @@ end
InvDecay(γ = 0.001) = InvDecay(γ, IdDict())
function update!(o::InvDecay, x, Δ)
function apply!(o::InvDecay, x, Δ)
γ = o.gamma
n = get!(o.state, x, 1)
Δ .*= 1 / (1 + γ * n)
@@ -300,7 +300,7 @@ end
ExpDecay(opt = 0.001, decay = 0.1, decay_step = 1000, clip = 1e-4) = ExpDecay(opt, decay, decay_step, clip, IdDict())
function update!(o::ExpDecay, x, Δ)
function apply!(o::ExpDecay, x, Δ)
η, s, decay = o.eta, o.step, o.decay
n = o.current[x] = get(o.current, x, 0) + 1
if o.current[x]%s == 0 && count(x -> x%s == 0, values(o.current)) == 1
@@ -321,7 +321,7 @@ end
WeightDecay() = WeightDecay(0)
function update!(o::WeightDecay, x, Δ)
function apply!(o::WeightDecay, x, Δ)
wd = o.wd
@. Δ += wd * x
end
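Because every rule is just `apply!` transforming `Δ`, rules compose: `Optimiser` threads the step through each sub-rule in order, so `WeightDecay` followed by `Descent` amounts to L2-regularised gradient descent. An illustrative sketch (the local variables and values are made up):

```julia
using Flux
using Flux: Optimise

w = rand(4)
g = rand(4)                                                   # pretend gradient for w

opt = Optimise.Optimiser(Optimise.WeightDecay(1e-4), Descent(0.01))
Δ = Optimise.apply!(opt, w, copy(g))                          # ≈ 0.01 .* (g .+ 1e-4 .* w)
w .-= Δ
```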

View File

@@ -1,10 +1,14 @@
using Juno
using Flux.Tracker: data, grad, back!
import Flux.Tracker: data, grad, back!, update!
import Base.depwarn
function update!(opt, xs)
function update!(opt, x, x̄)
update!(x, -apply!(opt, x, copy(data(x̄))))
end
function _update_params!(opt, xs)
for x in xs
Δ = update!(opt, x.data, x.grad)
Δ = apply!(opt, x.data, x.grad)
x.data .-= Δ
Δ .= 0
end
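The new three-argument `update!` is the public face of this change: it runs `apply!` for one parameter and hands the negated step to Tracker's in-place `update!`, while `_update_params!` remains the internal batch version used by `train!`. A hedged usage sketch, reusing `W`, `b`, `loss`, `x`, `y` and `opt` from the docs example above:

```julia
using Flux.Tracker: update!, gradient

θ  = Params([W, b])
gs = gradient(() -> loss(x, y), θ)

update!(opt, W, gs[W])   # one parameter: apply! the rule, then step W in place
update!(opt, b, gs[b])
```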
@@ -69,7 +73,7 @@ function train!(loss, ps, data, opt; cb = () -> ())
try
l = loss(d...)
@interrupts back!(l)
update!(opt, ps)
_update_params!(opt, ps)
if cb() == :stop
depwarn("Use of `:stop` is deprecated; use `Flux.stop()` instead", :stop)
break
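`train!` keeps its signature; only the internal parameter update changed, and the warning above points at `Flux.stop()` as the replacement for returning `:stop` from the callback. A sketch of the non-deprecated early-stopping style, reusing the docs example's `loss`, `x`, `y` and `opt` (the data iterator and threshold are illustrative):

```julia
data = Iterators.repeated((x, y), 1000)    # reuse the dummy batch 1000 times

cb() = loss(x, y) < 0.05 && Flux.stop()    # early stopping instead of returning :stop

Flux.train!(loss, Params([W, b]), data, opt, cb = cb)
```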

View File

@@ -17,7 +17,7 @@ using Test
for t = 1: 10^5
l = loss(rand(10))
back!(l)
delta = Optimise.update!(opt, w.data, w.grad)
delta = Optimise.apply!(opt, w.data, w.grad)
w.data .-= delta
end
@test Flux.mse(w, w′) < 0.01
@@ -33,7 +33,7 @@ end
for t = 1:10^5
l = loss(rand(10))
back!(l)
delta = Optimise.update!(opt, w.data, w.grad)
delta = Optimise.apply!(opt, w.data, w.grad)
w.data .-= delta
end
@test Flux.mse(w, w′) < 0.01
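The hunks above only show the inner loops; roughly, each testset trains a tracked matrix against a fixed target and then checks the final distance. A hedged reconstruction of that scaffolding (the variable roles are inferred for illustration, not copied from the test file):

```julia
using Flux, Test
using Flux: Optimise
using Flux.Tracker: param, back!

w′ = randn(10, 10)                # fixed target
w  = param(randn(10, 10))         # trained parameter
loss(x) = Flux.mse(w*x, w′*x)
opt = Descent(0.1)

for t = 1:10^5
  l = loss(rand(10))
  back!(l)
  delta = Optimise.apply!(opt, w.data, w.grad)
  w.data .-= delta
end
@test Flux.mse(w, w′) < 0.01
```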