update -> apply

commit 0f2975d905
parent bf0b5c5cef
@@ -4,7 +4,7 @@ using Flux: Params
 check_decay(opt, decay) = decay == 0 ? opt : Optimiser(opt, InvDecay(decay))

 # legacy update rule
-updaterule(opt, ps) = () -> update!(opt, ps)
+updaterule(opt, ps) = () -> _update_params!(opt, ps)

 function SGD(params::Union{AbstractArray, Params}, η = 0.1; decay = 0.)
   depwarn("SGD(params) is deprecated; use Descent(η::Float64) instead", :SGD)
@@ -117,7 +117,7 @@ struct OldOptimiser
   func
 end

-update!(opt::OldOptimiser, ps) = opt.func()
+_update_params!(opt::OldOptimiser, ps) = opt.func()

 # Train function
 function train!(loss, data, opt; cb = () -> ())
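The commit splits the old overloaded update! into two distinct operations: apply!(opt, x, Δ) turns a raw gradient Δ for a parameter x into the step to take, while _update_params!(opt, xs) walks the parameters and performs the subtraction (see the training-loop hunk further down). A minimal sketch of the new contract, assuming the Flux.Optimise module at this commit, with plain arrays standing in for a tracked parameter's data and gradient:

    using Flux.Optimise: Descent, apply!

    x  = randn(3)                     # stand-in for a parameter's data
    gx = randn(3)                     # stand-in for its gradient
    Δ  = apply!(Descent(0.1), x, gx)  # rescales gx in place and returns the step
    x .-= Δ                           # the caller, not apply!, mutates the parameter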
@@ -18,7 +18,7 @@ end

 Descent() = Descent(0.1)

-function update!(o::Descent, x, Δ)
+function apply!(o::Descent, x, Δ)
   Δ .*= o.eta
 end

@@ -35,7 +35,7 @@ end

 Momentum(η = 0.01, ρ = 0.9) = Momentum(η, ρ, IdDict())

-function update!(o::Momentum, x, Δ)
+function apply!(o::Momentum, x, Δ)
   η, ρ = o.eta, o.rho
   v = get!(o.velocity, x, zero(x))::typeof(x)
   @. v = ρ * v - η * Δ
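For reference, these lines implement classical momentum: the velocity decays by ρ and accumulates the scaled gradient, and the step the rule hands back (below the fold of this hunk) is -v. A self-contained sketch with made-up values:

    η, ρ = 0.01, 0.9
    x, Δ = randn(4), randn(4)
    v = zero(x)
    @. v = ρ * v - η * Δ   # decay old velocity, fold in the new gradient
    x .+= v                # equivalent to x .-= -v, the step the rule returns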
@@ -55,7 +55,7 @@ end

 Nesterov(η = 0.001, ρ = 0.9) = Nesterov(η, ρ, IdDict())

-function update!(o::Nesterov, x, Δ)
+function apply!(o::Nesterov, x, Δ)
   η, ρ = o.eta, o.rho
   v = get!(o.velocity, x, zero(x))::typeof(x)
   d = @. ρ^2 * v - (1+ρ) * η * Δ
@@ -78,7 +78,7 @@ end

 RMSProp(η = 0.001, ρ = 0.9) = RMSProp(η, ρ, IdDict())

-function update!(o::RMSProp, x, Δ)
+function apply!(o::RMSProp, x, Δ)
   η, ρ = o.eta, o.rho
   acc = get!(o.acc, x, zero(x))::typeof(x)
   @. acc = ρ * acc + (1 - ρ) * Δ^2
@@ -98,7 +98,7 @@ end

 ADAM(η = 0.001, β = (0.9, 0.999)) = ADAM(η, β, IdDict())

-function update!(o::ADAM, x, Δ)
+function apply!(o::ADAM, x, Δ)
   η, β = o.eta, o.beta
   mt, vt, βp = get!(o.state, x, (zero(x), zero(x), β))
   @. mt = β[1] * mt + (1 - β[1]) * Δ
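The βp entry of the state tuple holds the running powers of β used for bias correction, which lines below the fold of this hunk presumably advance after each step. A minimal sketch of one bias-corrected ADAM step (hyperparameters and arrays are made up; ϵ stands for the small module-level constant this file relies on):

    η, β, ϵ = 0.001, (0.9, 0.999), 1e-8
    x, Δ = randn(4), randn(4)
    mt, vt, βp = zero(x), zero(x), β
    @. mt = β[1] * mt + (1 - β[1]) * Δ      # first-moment estimate
    @. vt = β[2] * vt + (1 - β[2]) * Δ^2    # second-moment estimate
    step = @. η * mt / (1 - βp[1]) / (√(vt / (1 - βp[2])) + ϵ)
    x .-= step
    βp = βp .* β                            # advance the bias-correction powers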
@@ -122,7 +122,7 @@ end

 AdaMax(η = 0.001, β = (0.9, 0.999)) = AdaMax(η, β, IdDict())

-function update!(o::AdaMax, x, Δ)
+function apply!(o::AdaMax, x, Δ)
   η, β = o.eta, o.beta
   mt, ut, βp = get!(o.state, x, (zero(x), zero(x), β))
   @. mt = β[1] * mt + (1 - β[1]) * Δ
@@ -145,7 +145,7 @@ end

 ADAGrad(η = 0.1) = ADAGrad(η, IdDict())

-function update!(o::ADAGrad, x, Δ)
+function apply!(o::ADAGrad, x, Δ)
   η = o.eta
   acc = get!(o.acc, x, fill(ϵ, size(x)))::typeof(x)
   @. acc += Δ^2
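ADAGrad keeps only an accumulator of squared gradients (seeded with ϵ so the later division is safe), so coordinates with a large gradient history take ever smaller steps. A standalone sketch of the full step under that reading:

    η, ϵ = 0.1, 1e-8
    x, Δ = randn(4), randn(4)
    acc = fill(ϵ, size(x))
    @. acc += Δ^2            # per-coordinate history of squared gradients
    @. Δ *= η / (√acc + ϵ)   # effective rate shrinks as history grows
    x .-= Δ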
@@ -165,7 +165,7 @@ end

 ADADelta(ρ = 0.9) = ADADelta(ρ, IdDict())

-function update!(o::ADADelta, x, Δ)
+function apply!(o::ADADelta, x, Δ)
   ρ = o.rho
   acc, Δacc = get!(o.state, x, (zero(x), zero(x)))
   @. acc = ρ * acc + (1 - ρ) * Δ^2
@@ -188,7 +188,7 @@ end

 AMSGrad(η = 0.001, β = (0.9, 0.999)) = AMSGrad(η, β, IdDict())

-function update!(o::AMSGrad, x, Δ)
+function apply!(o::AMSGrad, x, Δ)
   η, β = o.eta, o.beta
   mt, vt, v̂t = get!(o.state, x, (fill(ϵ, size(x)), fill(ϵ, size(x)), fill(ϵ, size(x))))
   @. mt = β[1] * mt + (1 - β[1]) * Δ
@@ -211,7 +211,7 @@ end

 NADAM(η = 0.001, β = (0.9, 0.999)) = NADAM(η, β, IdDict())

-function update!(o::NADAM, x, Δ)
+function apply!(o::NADAM, x, Δ)
   η, β = o.eta, o.beta
   β1p, β2p = o.beta
   mt, vt = get!(o.state, x, (zero(x), zero(x)))
@@ -250,9 +250,9 @@ Optimiser(o...) = Optimiser(Any[o...])

 Base.getindex(c::Optimiser, i::AbstractArray) = Optimiser(c.os[i]...)

-function update!(o::Optimiser, x, Δ)
+function apply!(o::Optimiser, x, Δ)
   for opt in o.os
-    Δ = update!(opt, x, Δ)
+    Δ = apply!(opt, x, Δ)
   end
   return Δ
 end
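Optimiser chains rules by threading Δ through each apply! in turn, which is what lets the decay wrappers below compose with a base rule. A hypothetical composition (constructor arguments assumed from the fields used in this file):

    using Flux.Optimise: Optimiser, WeightDecay, Descent, apply!

    opt = Optimiser(WeightDecay(1e-4), Descent(0.1))  # decay the gradient, then scale it
    x, gx = randn(3), randn(3)
    x .-= apply!(opt, x, gx)   # each rule transforms Δ and passes it on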
@@ -272,7 +272,7 @@ end

 InvDecay(γ = 0.001) = InvDecay(γ, IdDict())

-function update!(o::InvDecay, x, Δ)
+function apply!(o::InvDecay, x, Δ)
   γ = o.gamma
   n = get!(o.state, x, 1)
   Δ .*= 1 / (1 + γ * n)
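So InvDecay scales the n-th step for a parameter by 1 / (1 + γn), a harmonic-style decay toward zero. A quick look at the schedule with the default γ:

    γ = 0.001
    for n in (1, 100, 1000)
        println("update ", n, ": scale = ", 1 / (1 + γ * n))
    end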
@@ -300,7 +300,7 @@ end

 ExpDecay(opt = 0.001, decay = 0.1, decay_step = 1000, clip = 1e-4) = ExpDecay(opt, decay, decay_step, clip, IdDict())

-function update!(o::ExpDecay, x, Δ)
+function apply!(o::ExpDecay, x, Δ)
   η, s, decay = o.eta, o.step, o.decay
   n = o.current[x] = get(o.current, x, 0) + 1
   if o.current[x]%s == 0 && count(x -> x%s == 0, values(o.current)) == 1
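On this reading, ExpDecay multiplies the effective η by decay once every decay_step updates (the count(...) == 1 guard fires the decay only once per boundary across all parameters), never letting η fall below clip. A sketch of the resulting schedule with this hunk's defaults:

    η, decay, clip = 0.001, 0.1, 1e-4
    for boundary in 1:3        # i.e. updates 1000, 2000, 3000 with decay_step = 1000
        η = max(η * decay, clip)
        println("after boundary ", boundary, ": η = ", η)
    end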
@@ -321,7 +321,7 @@ end

 WeightDecay() = WeightDecay(0)

-function update!(o::WeightDecay, x, Δ)
+function apply!(o::WeightDecay, x, Δ)
   wd = o.wd
   @. Δ += wd * x
 end
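Adding wd * x to the gradient before the step is L2 regularisation folded into the update. Equivalently, with made-up values and plain descent standing in as the base rule:

    wd, η = 1e-4, 0.1
    x, Δ = randn(4), randn(4)
    @. Δ += wd * x    # gradient of (wd/2) * sum(abs2, x), added in place
    x .-= η .* Δ      # the base rule then takes the step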
@@ -2,9 +2,9 @@ using Juno
 using Flux.Tracker: data, grad, back!
 import Base.depwarn

-function update!(opt, xs)
+function _update_params!(opt, xs)
   for x in xs
-    Δ = update!(opt, x.data, x.grad)
+    Δ = apply!(opt, x.data, x.grad)
     x.data .-= Δ
     Δ .= 0
   end
@@ -69,7 +69,7 @@ function train!(loss, ps, data, opt; cb = () -> ())
     try
       l = loss(d...)
      @interrupts back!(l)
-      update!(opt, ps)
+      _update_params!(opt, ps)
       if cb() == :stop
        depwarn("Use of `:stop` is deprecated; use `Flux.stop()` instead", :stop)
        break
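Early stopping from the callback still works through this deprecation path, but Flux.stop() is now the supported mechanism, per the warning above. A hypothetical usage (loss, ps, data, opt, and done are stand-ins):

    cb = () -> done() && Flux.stop()           # preferred over returning :stop
    Flux.train!(loss, ps, data, opt, cb = cb)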
@@ -17,7 +17,7 @@ using Test
   for t = 1:10^5
     l = loss(rand(10))
     back!(l)
-    delta = Optimise.update!(opt, w′.data, w′.grad)
+    delta = Optimise.apply!(opt, w′.data, w′.grad)
     w′.data .-= delta
   end
   @test Flux.mse(w, w′) < 0.01
@@ -33,7 +33,7 @@ end
   for t = 1:10^5
     l = loss(rand(10))
     back!(l)
-    delta = Optimise.update!(opt, w′.data, w′.grad)
+    delta = Optimise.apply!(opt, w′.data, w′.grad)
     w′.data .-= delta
   end
   @test Flux.mse(w, w′) < 0.01