Merge pull request #575 from FluxML/mji/update
Clean up parameter update API
This commit is contained in:
commit
8386a49bf9
|
@ -3,7 +3,7 @@
|
|||
Consider a [simple linear regression](../models/basics.md). We create some dummy data, calculate a loss, and backpropagate to calculate gradients for the parameters `W` and `b`.
|
||||
|
||||
```julia
|
||||
using Flux.Tracker
|
||||
using Flux, Flux.Tracker
|
||||
|
||||
W = param(rand(2, 5))
|
||||
b = param(rand(2))
|
||||
|
@ -14,8 +14,8 @@ loss(x, y) = sum((predict(x) .- y).^2)
|
|||
x, y = rand(5), rand(2) # Dummy data
|
||||
l = loss(x, y) # ~ 3
|
||||
|
||||
params = Params([W, b])
|
||||
grads = Tracker.gradient(() -> loss(x, y), params)
|
||||
θ = Params([W, b])
|
||||
grads = Tracker.gradient(() -> loss(x, y), θ)
|
||||
```
|
||||
|
||||
We want to update each parameter, using the gradient, in order to improve (reduce) the loss. Here's one way to do that:
|
||||
|
@ -35,7 +35,7 @@ Running this will alter the parameters `W` and `b` and our loss should go down.
|
|||
opt = Descent(0.1) # Gradient descent with learning rate 0.1
|
||||
|
||||
for p in (W, b)
|
||||
update!(opt, p, -η * grads[p])
|
||||
update!(opt, p, grads[p])
|
||||
end
|
||||
```
|
||||
|
||||
|
|
|
@ -4,7 +4,7 @@ using Flux: Params
|
|||
check_decay(opt, decay) = decay == 0 ? opt : Optimiser(opt, InvDecay(decay))
|
||||
|
||||
# legacy update rule
|
||||
updaterule(opt, ps) = () -> update!(opt, ps)
|
||||
updaterule(opt, ps) = () -> _update_params!(opt, ps)
|
||||
|
||||
function SGD(params::Union{AbstractArray, Params}, η = 0.1; decay = 0.)
|
||||
depwarn("SGD(params) is deprecated; use Descent(η::Float64) instead", :SGD)
|
||||
|
@ -117,7 +117,7 @@ struct OldOptimiser
|
|||
func
|
||||
end
|
||||
|
||||
update!(opt::OldOptimiser, ps) = opt.func()
|
||||
_update_params!(opt::OldOptimiser, ps) = opt.func()
|
||||
|
||||
# Train function
|
||||
function train!(loss, data, opt; cb = () -> ())
|
||||
|
|
|
@ -18,7 +18,7 @@ end
|
|||
|
||||
Descent() = Descent(0.1)
|
||||
|
||||
function update!(o::Descent, x, Δ)
|
||||
function apply!(o::Descent, x, Δ)
|
||||
Δ .*= o.eta
|
||||
end
|
||||
|
||||
|
@ -35,7 +35,7 @@ end
|
|||
|
||||
Momentum(η = 0.01, ρ = 0.9) = Momentum(η, ρ, IdDict())
|
||||
|
||||
function update!(o::Momentum, x, Δ)
|
||||
function apply!(o::Momentum, x, Δ)
|
||||
η, ρ = o.eta, o.rho
|
||||
v = get!(o.velocity, x, zero(x))::typeof(x)
|
||||
@. v = ρ * v - η * Δ
|
||||
|
@ -55,7 +55,7 @@ end
|
|||
|
||||
Nesterov(η = 0.001, ρ = 0.9) = Nesterov(η, ρ, IdDict())
|
||||
|
||||
function update!(o::Nesterov, x, Δ)
|
||||
function apply!(o::Nesterov, x, Δ)
|
||||
η, ρ = o.eta, o.rho
|
||||
v = get!(o.velocity, x, zero(x))::typeof(x)
|
||||
d = @. ρ^2 * v - (1+ρ) * η * Δ
|
||||
|
@ -78,7 +78,7 @@ end
|
|||
|
||||
RMSProp(η = 0.001, ρ = 0.9) = RMSProp(η, ρ, IdDict())
|
||||
|
||||
function update!(o::RMSProp, x, Δ)
|
||||
function apply!(o::RMSProp, x, Δ)
|
||||
η, ρ = o.eta, o.rho
|
||||
acc = get!(o.acc, x, zero(x))::typeof(x)
|
||||
@. acc = ρ * acc + (1 - ρ) * Δ^2
|
||||
|
@ -98,7 +98,7 @@ end
|
|||
|
||||
ADAM(η = 0.001, β = (0.9, 0.999)) = ADAM(η, β, IdDict())
|
||||
|
||||
function update!(o::ADAM, x, Δ)
|
||||
function apply!(o::ADAM, x, Δ)
|
||||
η, β = o.eta, o.beta
|
||||
mt, vt, βp = get!(o.state, x, (zero(x), zero(x), β))
|
||||
@. mt = β[1] * mt + (1 - β[1]) * Δ
|
||||
|
@ -122,7 +122,7 @@ end
|
|||
|
||||
AdaMax(η = 0.001, β = (0.9, 0.999)) = AdaMax(η, β, IdDict())
|
||||
|
||||
function update!(o::AdaMax, x, Δ)
|
||||
function apply!(o::AdaMax, x, Δ)
|
||||
η, β = o.eta, o.beta
|
||||
mt, ut, βp = get!(o.state, x, (zero(x), zero(x), β))
|
||||
@. mt = β[1] * mt + (1 - β[1]) * Δ
|
||||
|
@ -145,7 +145,7 @@ end
|
|||
|
||||
ADAGrad(η = 0.1) = ADAGrad(η, IdDict())
|
||||
|
||||
function update!(o::ADAGrad, x, Δ)
|
||||
function apply!(o::ADAGrad, x, Δ)
|
||||
η = o.eta
|
||||
acc = get!(o.acc, x, fill(ϵ, size(x)))::typeof(x)
|
||||
@. acc += Δ^2
|
||||
|
@ -165,7 +165,7 @@ end
|
|||
|
||||
ADADelta(ρ = 0.9) = ADADelta(ρ, IdDict())
|
||||
|
||||
function update!(o::ADADelta, x, Δ)
|
||||
function apply!(o::ADADelta, x, Δ)
|
||||
ρ = o.rho
|
||||
acc, Δacc = get!(o.state, x, (zero(x), zero(x)))
|
||||
@. acc = ρ * acc + (1 - ρ) * Δ^2
|
||||
|
@ -188,7 +188,7 @@ end
|
|||
|
||||
AMSGrad(η = 0.001, β = (0.9, 0.999)) = AMSGrad(η, β, IdDict())
|
||||
|
||||
function update!(o::AMSGrad, x, Δ)
|
||||
function apply!(o::AMSGrad, x, Δ)
|
||||
η, β = o.eta, o.beta
|
||||
mt, vt, v̂t = get!(o.state, x, (fill(ϵ, size(x)), fill(ϵ, size(x)), fill(ϵ, size(x))))
|
||||
@. mt = β[1] * mt + (1 - β[1]) * Δ
|
||||
|
@ -211,7 +211,7 @@ end
|
|||
|
||||
NADAM(η = 0.001, β = (0.9, 0.999)) = NADAM(η, β, IdDict())
|
||||
|
||||
function update!(o::NADAM, x, Δ)
|
||||
function apply!(o::NADAM, x, Δ)
|
||||
η, β = o.eta, o.beta
|
||||
β1p, β2p = o.beta
|
||||
mt, vt = get!(o.state, x, (zero(x), zero(x)))
|
||||
|
@ -250,9 +250,9 @@ Optimiser(o...) = Optimiser(Any[o...])
|
|||
|
||||
Base.getindex(c::Optimiser, i::AbstractArray) = Optimiser(c.os[i]...)
|
||||
|
||||
function update!(o::Optimiser, x, Δ)
|
||||
function apply!(o::Optimiser, x, Δ)
|
||||
for opt in o.os
|
||||
Δ = update!(opt, x, Δ)
|
||||
Δ = apply!(opt, x, Δ)
|
||||
end
|
||||
return Δ
|
||||
end
|
||||
|
@ -272,7 +272,7 @@ end
|
|||
|
||||
InvDecay(γ = 0.001) = InvDecay(γ, IdDict())
|
||||
|
||||
function update!(o::InvDecay, x, Δ)
|
||||
function apply!(o::InvDecay, x, Δ)
|
||||
γ = o.gamma
|
||||
n = get!(o.state, x, 1)
|
||||
Δ .*= 1 / (1 + γ * n)
|
||||
|
@ -300,7 +300,7 @@ end
|
|||
|
||||
ExpDecay(opt = 0.001, decay = 0.1, decay_step = 1000, clip = 1e-4) = ExpDecay(opt, decay, decay_step, clip, IdDict())
|
||||
|
||||
function update!(o::ExpDecay, x, Δ)
|
||||
function apply!(o::ExpDecay, x, Δ)
|
||||
η, s, decay = o.eta, o.step, o.decay
|
||||
n = o.current[x] = get(o.current, x, 0) + 1
|
||||
if o.current[x]%s == 0 && count(x -> x%s == 0, values(o.current)) == 1
|
||||
|
@ -321,7 +321,7 @@ end
|
|||
|
||||
WeightDecay() = WeightDecay(0)
|
||||
|
||||
function update!(o::WeightDecay, x, Δ)
|
||||
function apply!(o::WeightDecay, x, Δ)
|
||||
wd = o.wd
|
||||
@. Δ += wd * x
|
||||
end
|
||||
|
|
|
@ -1,10 +1,14 @@
|
|||
using Juno
|
||||
using Flux.Tracker: data, grad, back!
|
||||
import Flux.Tracker: data, grad, back!, update!
|
||||
import Base.depwarn
|
||||
|
||||
function update!(opt, xs)
|
||||
function update!(opt, x, x̄)
|
||||
update!(x, apply!(opt, x, copy(data(x̄))))
|
||||
end
|
||||
|
||||
function _update_params!(opt, xs)
|
||||
for x in xs
|
||||
Δ = update!(opt, x.data, x.grad)
|
||||
Δ = apply!(opt, x.data, x.grad)
|
||||
x.data .-= Δ
|
||||
Δ .= 0
|
||||
end
|
||||
|
@ -69,7 +73,7 @@ function train!(loss, ps, data, opt; cb = () -> ())
|
|||
try
|
||||
l = loss(d...)
|
||||
@interrupts back!(l)
|
||||
update!(opt, ps)
|
||||
_update_params!(opt, ps)
|
||||
if cb() == :stop
|
||||
depwarn("Use of `:stop` is deprecated; use `Flux.stop()` instead", :stop)
|
||||
break
|
||||
|
|
|
@ -17,7 +17,7 @@ using Test
|
|||
for t = 1: 10^5
|
||||
l = loss(rand(10))
|
||||
back!(l)
|
||||
delta = Optimise.update!(opt, w′.data, w′.grad)
|
||||
delta = Optimise.apply!(opt, w′.data, w′.grad)
|
||||
w′.data .-= delta
|
||||
end
|
||||
@test Flux.mse(w, w′) < 0.01
|
||||
|
@ -33,7 +33,7 @@ end
|
|||
for t = 1:10^5
|
||||
l = loss(rand(10))
|
||||
back!(l)
|
||||
delta = Optimise.update!(opt, w′.data, w′.grad)
|
||||
delta = Optimise.apply!(opt, w′.data, w′.grad)
|
||||
w′.data .-= delta
|
||||
end
|
||||
@test Flux.mse(w, w′) < 0.01
|
||||
|
|
Loading…
Reference in New Issue