update -> apply

This commit is contained in:
Mike J Innes 2019-01-28 13:59:23 +00:00
parent bf0b5c5cef
commit 0f2975d905
4 changed files with 22 additions and 22 deletions

View File

@ -4,7 +4,7 @@ using Flux: Params
check_decay(opt, decay) = decay == 0 ? opt : Optimiser(opt, InvDecay(decay)) check_decay(opt, decay) = decay == 0 ? opt : Optimiser(opt, InvDecay(decay))
# legacy update rule # legacy update rule
updaterule(opt, ps) = () -> update!(opt, ps) updaterule(opt, ps) = () -> _update_params!(opt, ps)
function SGD(params::Union{AbstractArray, Params}, η = 0.1; decay = 0.) function SGD(params::Union{AbstractArray, Params}, η = 0.1; decay = 0.)
depwarn("SGD(params) is deprecated; use Descent(η::Float64) instead", :SGD) depwarn("SGD(params) is deprecated; use Descent(η::Float64) instead", :SGD)
@ -117,7 +117,7 @@ struct OldOptimiser
func func
end end
update!(opt::OldOptimiser, ps) = opt.func() _update_params!(opt::OldOptimiser, ps) = opt.func()
# Train function # Train function
function train!(loss, data, opt; cb = () -> ()) function train!(loss, data, opt; cb = () -> ())

View File

@ -18,7 +18,7 @@ end
Descent() = Descent(0.1) Descent() = Descent(0.1)
function update!(o::Descent, x, Δ) function apply!(o::Descent, x, Δ)
Δ .*= o.eta Δ .*= o.eta
end end
@ -35,7 +35,7 @@ end
Momentum(η = 0.01, ρ = 0.9) = Momentum(η, ρ, IdDict()) Momentum(η = 0.01, ρ = 0.9) = Momentum(η, ρ, IdDict())
function update!(o::Momentum, x, Δ) function apply!(o::Momentum, x, Δ)
η, ρ = o.eta, o.rho η, ρ = o.eta, o.rho
v = get!(o.velocity, x, zero(x))::typeof(x) v = get!(o.velocity, x, zero(x))::typeof(x)
@. v = ρ * v - η * Δ @. v = ρ * v - η * Δ
@ -55,7 +55,7 @@ end
Nesterov(η = 0.001, ρ = 0.9) = Nesterov(η, ρ, IdDict()) Nesterov(η = 0.001, ρ = 0.9) = Nesterov(η, ρ, IdDict())
function update!(o::Nesterov, x, Δ) function apply!(o::Nesterov, x, Δ)
η, ρ = o.eta, o.rho η, ρ = o.eta, o.rho
v = get!(o.velocity, x, zero(x))::typeof(x) v = get!(o.velocity, x, zero(x))::typeof(x)
d = @. ρ^2 * v - (1+ρ) * η * Δ d = @. ρ^2 * v - (1+ρ) * η * Δ
@ -78,7 +78,7 @@ end
RMSProp(η = 0.001, ρ = 0.9) = RMSProp(η, ρ, IdDict()) RMSProp(η = 0.001, ρ = 0.9) = RMSProp(η, ρ, IdDict())
function update!(o::RMSProp, x, Δ) function apply!(o::RMSProp, x, Δ)
η, ρ = o.eta, o.rho η, ρ = o.eta, o.rho
acc = get!(o.acc, x, zero(x))::typeof(x) acc = get!(o.acc, x, zero(x))::typeof(x)
@. acc = ρ * acc + (1 - ρ) * Δ^2 @. acc = ρ * acc + (1 - ρ) * Δ^2
@ -98,7 +98,7 @@ end
ADAM(η = 0.001, β = (0.9, 0.999)) = ADAM(η, β, IdDict()) ADAM(η = 0.001, β = (0.9, 0.999)) = ADAM(η, β, IdDict())
function update!(o::ADAM, x, Δ) function apply!(o::ADAM, x, Δ)
η, β = o.eta, o.beta η, β = o.eta, o.beta
mt, vt, βp = get!(o.state, x, (zero(x), zero(x), β)) mt, vt, βp = get!(o.state, x, (zero(x), zero(x), β))
@. mt = β[1] * mt + (1 - β[1]) * Δ @. mt = β[1] * mt + (1 - β[1]) * Δ
@ -122,7 +122,7 @@ end
AdaMax(η = 0.001, β = (0.9, 0.999)) = AdaMax(η, β, IdDict()) AdaMax(η = 0.001, β = (0.9, 0.999)) = AdaMax(η, β, IdDict())
function update!(o::AdaMax, x, Δ) function apply!(o::AdaMax, x, Δ)
η, β = o.eta, o.beta η, β = o.eta, o.beta
mt, ut, βp = get!(o.state, x, (zero(x), zero(x), β)) mt, ut, βp = get!(o.state, x, (zero(x), zero(x), β))
@. mt = β[1] * mt + (1 - β[1]) * Δ @. mt = β[1] * mt + (1 - β[1]) * Δ
@ -145,7 +145,7 @@ end
ADAGrad(η = 0.1) = ADAGrad(η, IdDict()) ADAGrad(η = 0.1) = ADAGrad(η, IdDict())
function update!(o::ADAGrad, x, Δ) function apply!(o::ADAGrad, x, Δ)
η = o.eta η = o.eta
acc = get!(o.acc, x, fill(ϵ, size(x)))::typeof(x) acc = get!(o.acc, x, fill(ϵ, size(x)))::typeof(x)
@. acc += Δ^2 @. acc += Δ^2
@ -165,7 +165,7 @@ end
ADADelta(ρ = 0.9) = ADADelta(ρ, IdDict()) ADADelta(ρ = 0.9) = ADADelta(ρ, IdDict())
function update!(o::ADADelta, x, Δ) function apply!(o::ADADelta, x, Δ)
ρ = o.rho ρ = o.rho
acc, Δacc = get!(o.state, x, (zero(x), zero(x))) acc, Δacc = get!(o.state, x, (zero(x), zero(x)))
@. acc = ρ * acc + (1 - ρ) * Δ^2 @. acc = ρ * acc + (1 - ρ) * Δ^2
@ -188,7 +188,7 @@ end
AMSGrad(η = 0.001, β = (0.9, 0.999)) = AMSGrad(η, β, IdDict()) AMSGrad(η = 0.001, β = (0.9, 0.999)) = AMSGrad(η, β, IdDict())
function update!(o::AMSGrad, x, Δ) function apply!(o::AMSGrad, x, Δ)
η, β = o.eta, o.beta η, β = o.eta, o.beta
mt, vt, v̂t = get!(o.state, x, (fill(ϵ, size(x)), fill(ϵ, size(x)), fill(ϵ, size(x)))) mt, vt, v̂t = get!(o.state, x, (fill(ϵ, size(x)), fill(ϵ, size(x)), fill(ϵ, size(x))))
@. mt = β[1] * mt + (1 - β[1]) * Δ @. mt = β[1] * mt + (1 - β[1]) * Δ
@ -211,7 +211,7 @@ end
NADAM(η = 0.001, β = (0.9, 0.999)) = NADAM(η, β, IdDict()) NADAM(η = 0.001, β = (0.9, 0.999)) = NADAM(η, β, IdDict())
function update!(o::NADAM, x, Δ) function apply!(o::NADAM, x, Δ)
η, β = o.eta, o.beta η, β = o.eta, o.beta
β1p, β2p = o.beta β1p, β2p = o.beta
mt, vt = get!(o.state, x, (zero(x), zero(x))) mt, vt = get!(o.state, x, (zero(x), zero(x)))
@ -250,9 +250,9 @@ Optimiser(o...) = Optimiser(Any[o...])
Base.getindex(c::Optimiser, i::AbstractArray) = Optimiser(c.os[i]...) Base.getindex(c::Optimiser, i::AbstractArray) = Optimiser(c.os[i]...)
function update!(o::Optimiser, x, Δ) function apply!(o::Optimiser, x, Δ)
for opt in o.os for opt in o.os
Δ = update!(opt, x, Δ) Δ = apply!(opt, x, Δ)
end end
return Δ return Δ
end end
@ -272,7 +272,7 @@ end
InvDecay(γ = 0.001) = InvDecay(γ, IdDict()) InvDecay(γ = 0.001) = InvDecay(γ, IdDict())
function update!(o::InvDecay, x, Δ) function apply!(o::InvDecay, x, Δ)
γ = o.gamma γ = o.gamma
n = get!(o.state, x, 1) n = get!(o.state, x, 1)
Δ .*= 1 / (1 + γ * n) Δ .*= 1 / (1 + γ * n)
@ -300,7 +300,7 @@ end
ExpDecay(opt = 0.001, decay = 0.1, decay_step = 1000, clip = 1e-4) = ExpDecay(opt, decay, decay_step, clip, IdDict()) ExpDecay(opt = 0.001, decay = 0.1, decay_step = 1000, clip = 1e-4) = ExpDecay(opt, decay, decay_step, clip, IdDict())
function update!(o::ExpDecay, x, Δ) function apply!(o::ExpDecay, x, Δ)
η, s, decay = o.eta, o.step, o.decay η, s, decay = o.eta, o.step, o.decay
n = o.current[x] = get(o.current, x, 0) + 1 n = o.current[x] = get(o.current, x, 0) + 1
if o.current[x]%s == 0 && count(x -> x%s == 0, values(o.current)) == 1 if o.current[x]%s == 0 && count(x -> x%s == 0, values(o.current)) == 1
@ -321,7 +321,7 @@ end
WeightDecay() = WeightDecay(0) WeightDecay() = WeightDecay(0)
function update!(o::WeightDecay, x, Δ) function apply!(o::WeightDecay, x, Δ)
wd = o.wd wd = o.wd
@. Δ += wd * x @. Δ += wd * x
end end

View File

@ -2,9 +2,9 @@ using Juno
using Flux.Tracker: data, grad, back! using Flux.Tracker: data, grad, back!
import Base.depwarn import Base.depwarn
function update!(opt, xs) function _update_params!(opt, xs)
for x in xs for x in xs
Δ = update!(opt, x.data, x.grad) Δ = apply!(opt, x.data, x.grad)
x.data .-= Δ x.data .-= Δ
Δ .= 0 Δ .= 0
end end
@ -69,7 +69,7 @@ function train!(loss, ps, data, opt; cb = () -> ())
try try
l = loss(d...) l = loss(d...)
@interrupts back!(l) @interrupts back!(l)
update!(opt, ps) _update_params!(opt, ps)
if cb() == :stop if cb() == :stop
depwarn("Use of `:stop` is deprecated; use `Flux.stop()` instead", :stop) depwarn("Use of `:stop` is deprecated; use `Flux.stop()` instead", :stop)
break break

View File

@ -17,7 +17,7 @@ using Test
for t = 1: 10^5 for t = 1: 10^5
l = loss(rand(10)) l = loss(rand(10))
back!(l) back!(l)
delta = Optimise.update!(opt, w.data, w.grad) delta = Optimise.apply!(opt, w.data, w.grad)
w.data .-= delta w.data .-= delta
end end
@test Flux.mse(w, w) < 0.01 @test Flux.mse(w, w) < 0.01
@ -33,7 +33,7 @@ end
for t = 1:10^5 for t = 1:10^5
l = loss(rand(10)) l = loss(rand(10))
back!(l) back!(l)
delta = Optimise.update!(opt, w.data, w.grad) delta = Optimise.apply!(opt, w.data, w.grad)
w.data .-= delta w.data .-= delta
end end
@test Flux.mse(w, w) < 0.01 @test Flux.mse(w, w) < 0.01