Merge pull request #379 from dhairyagandhi96/master
New optimisers interface
This commit is contained in:
commit
43c5f90d93
|
@ -4,7 +4,6 @@ os:
|
||||||
- linux
|
- linux
|
||||||
# - osx
|
# - osx
|
||||||
julia:
|
julia:
|
||||||
- 0.7
|
|
||||||
- 1.0
|
- 1.0
|
||||||
- nightly
|
- nightly
|
||||||
# uncomment the following lines to override the default test script
|
# uncomment the following lines to override the default test script
|
||||||
|
|
2
REQUIRE
2
REQUIRE
|
@ -1,4 +1,4 @@
|
||||||
julia 0.7
|
julia 1.0
|
||||||
Juno
|
Juno
|
||||||
MacroTools 0.3.3
|
MacroTools 0.3.3
|
||||||
NNlib
|
NNlib
|
||||||
|
|
|
@ -19,8 +19,9 @@ export Tracker, TrackedArray, TrackedVector, TrackedMatrix, param
|
||||||
include("optimise/Optimise.jl")
|
include("optimise/Optimise.jl")
|
||||||
using .Optimise
|
using .Optimise
|
||||||
using .Optimise: @epochs
|
using .Optimise: @epochs
|
||||||
export SGD, ADAM, ADAMW, AdaMax, Momentum, Nesterov,
|
export SGD, Descent, ADAM, Momentum, Nesterov, RMSProp,
|
||||||
RMSProp, ADAGrad, ADADelta, AMSGrad, NADAM
|
ADAGrad, AdaMax, ADADelta, AMSGrad, NADAM,
|
||||||
|
ADAMW, InvDecay, ExpDecay, WeightDecay
|
||||||
|
|
||||||
include("utils.jl")
|
include("utils.jl")
|
||||||
include("onehot.jl")
|
include("onehot.jl")
|
||||||
|
|
|
@ -1,23 +1,12 @@
|
||||||
module Optimise
|
module Optimise
|
||||||
|
|
||||||
export train!,
|
export train!,
|
||||||
SGD, ADAM, ADAMW, AdaMax, Momentum, Nesterov,
|
SGD, Descent, ADAM, Momentum, Nesterov, RMSProp,
|
||||||
RMSProp, ADAGrad, ADADelta, AMSGrad, NADAM, stop, StopException
|
ADAGrad, AdaMax, ADADelta, AMSGrad, NADAM, ADAMW,
|
||||||
|
InvDecay, ExpDecay, WeightDecay, stop, Optimiser
|
||||||
struct Param{T}
|
|
||||||
x::T
|
|
||||||
Δ::T
|
|
||||||
end
|
|
||||||
|
|
||||||
Param(x::AbstractArray) = Param(x, zero(x))
|
|
||||||
|
|
||||||
include("optimisers.jl")
|
include("optimisers.jl")
|
||||||
include("interface.jl")
|
|
||||||
include("train.jl")
|
include("train.jl")
|
||||||
|
include("deprecations.jl")
|
||||||
using Flux.Tracker: TrackedArray
|
|
||||||
|
|
||||||
Param(x::TrackedArray) = Param(x.data, x.grad)
|
|
||||||
# Base.convert(::Type{Param}, x::TrackedArray) = Param(x.data, x.grad)
|
|
||||||
|
|
||||||
end
|
end
|
||||||
|
|
|
@ -0,0 +1,125 @@
|
||||||
|
using Base: depwarn
|
||||||
|
|
||||||
|
check_decay(opt, decay) = decay == 0 ? opt : Optimiser(opt, InvDecay(decay))
|
||||||
|
|
||||||
|
# legacy update rule
|
||||||
|
updaterule(opt, ps) = () -> update!(p, ps)
|
||||||
|
|
||||||
|
function SGD(params::AbstractArray, η = 0.1; decay = 0.)
|
||||||
|
depwarn("SGD(params) is deprecated; use Descent(η::Float64) instead", :SGD)
|
||||||
|
|
||||||
|
ps = params
|
||||||
|
opt = Descent(η)
|
||||||
|
opt = check_decay(opt, decay)
|
||||||
|
updaterule(opt, ps)
|
||||||
|
end
|
||||||
|
|
||||||
|
function Momentum(params::AbstractArray, η = 0.01; ρ = 0.9, decay = 0.)
|
||||||
|
depwarn("Momentum(params) is deprecated; use Momentum(η::Float64) instead", :Momentum)
|
||||||
|
|
||||||
|
ps = params
|
||||||
|
opt = Momentum(η, ρ)
|
||||||
|
opt = check_decay(opt, decay)
|
||||||
|
updaterule(opt, ps)
|
||||||
|
end
|
||||||
|
|
||||||
|
function Nesterov(params::AbstractArray, η = 0.001; ρ = 0.9, decay = 0.)
|
||||||
|
depwarn("Nesterov(params) is deprecated; use Nesterov(η::Float64) instead", :Nesterov)
|
||||||
|
|
||||||
|
ps = params
|
||||||
|
opt = Nesterov(η, ρ)
|
||||||
|
opt = check_decay(opt, decay)
|
||||||
|
updaterule(opt, ps)
|
||||||
|
end
|
||||||
|
|
||||||
|
function RMSProp(params::AbstractArray, η = 0.001; ρ = 0.9, decay = 0.)
|
||||||
|
depwarn("RMSProp(params) is deprecated; use RMSProp(η::Float64) instead", :RMSProp)
|
||||||
|
|
||||||
|
ps = params
|
||||||
|
opt = RMSProp(η, ρ)
|
||||||
|
opt = check_decay(opt, decay)
|
||||||
|
updaterule(opt, ps)
|
||||||
|
end
|
||||||
|
|
||||||
|
function ADAM(params::AbstractArray, η = 0.001; β1 = 0.9, β2 = 0.999, decay = 0.)
|
||||||
|
depwarn("ADAM(params) is deprecated; use ADAM(η::Float64) instead", :ADAM)
|
||||||
|
|
||||||
|
ps = params
|
||||||
|
β = (β1, β2)
|
||||||
|
opt = ADAM(η, β)
|
||||||
|
opt = check_decay(opt, decay)
|
||||||
|
updaterule(opt, ps)
|
||||||
|
end
|
||||||
|
|
||||||
|
function ADAGrad(params::AbstractArray, η::Float64 = 0.1; decay = 0.)
|
||||||
|
depwarn("ADAGrad(params) is deprecated; use ADAGrad(η::Float64) instead", :ADAGrad)
|
||||||
|
|
||||||
|
ps = params
|
||||||
|
opt = ADAGrad(η)
|
||||||
|
opt = check_decay(opt, decay)
|
||||||
|
updaterule(opt, ps)
|
||||||
|
end
|
||||||
|
|
||||||
|
function ADADelta(params::AbstractArray, ρ::Float64 = 0.9; decay = 0.)
|
||||||
|
depwarn("ADADelta(params) is deprecated; use ADADelta(η::Float64) instead", :ADADelta)
|
||||||
|
|
||||||
|
ps = params
|
||||||
|
opt = ADADelta(ρ)
|
||||||
|
opt = check_decay(opt, decay)
|
||||||
|
updaterule(opt, ps)
|
||||||
|
end
|
||||||
|
|
||||||
|
function AdaMax(params::AbstractArray, η = 0.001; β1 = 0.9, β2 = 0.999, decay = 0.)
|
||||||
|
depwarn("AdaMax(params) is deprecated; use AdaMax(η::Float64) instead", :AdaMax)
|
||||||
|
|
||||||
|
ps = params
|
||||||
|
β = (β1, β2)
|
||||||
|
opt = AdaMax(η, β)
|
||||||
|
opt = check_decay(opt, decay)
|
||||||
|
updaterule(opt, ps)
|
||||||
|
end
|
||||||
|
|
||||||
|
function AMSGrad(params::AbstractArray, η = 0.001; β1 = 0.9, β2 = 0.999, decay = 0.)
|
||||||
|
depwarn("AMSGrad(params) is deprecated; use AMSGrad(η::Float64) instead", :AMSGrad)
|
||||||
|
|
||||||
|
ps = params
|
||||||
|
β = (β1, β2)
|
||||||
|
opt = AMSGrad(η, β)
|
||||||
|
opt = check_decay(opt, decay)
|
||||||
|
updaterule(opt, ps)
|
||||||
|
end
|
||||||
|
|
||||||
|
function NADAM(params::AbstractArray, η = 0.001; β1 = 0.9, β2 = 0.999, decay = 0.)
|
||||||
|
depwarn("NADAM(params) is deprecated; use NADAM(η::Float64) instead", :NADAM)
|
||||||
|
|
||||||
|
ps = params
|
||||||
|
β = (β1, β2)
|
||||||
|
opt = NADAM(η, β)
|
||||||
|
opt = check_decay(opt, decay)
|
||||||
|
updaterule(opt, ps)
|
||||||
|
end
|
||||||
|
|
||||||
|
function ADAMW(params::AbstractArray, η = 0.001; β1 = 0.9, β2 = 0.999, decay = 0.)
|
||||||
|
depwarn("ADAMW(params) is deprecated; use ADAMW(η::Float64) instead", :ADAMW)
|
||||||
|
|
||||||
|
ps = params
|
||||||
|
β = (β1, β2)
|
||||||
|
opt = ADAMW(η, β)
|
||||||
|
opt = check_decay(opt, decay)
|
||||||
|
decay != 0 && (opt = Optimiser(opt, WeightDecay(decay)))
|
||||||
|
updaterule(opt, ps)
|
||||||
|
end
|
||||||
|
|
||||||
|
# Old training loop
|
||||||
|
|
||||||
|
struct OldOptimiser
|
||||||
|
func
|
||||||
|
end
|
||||||
|
|
||||||
|
update!(opt::OldOptimiser, ps) = opt.func()
|
||||||
|
|
||||||
|
# Train function
|
||||||
|
function train!(loss, data, opt; cb = () -> ())
|
||||||
|
depwarn("train!(loss, data, opt) is deprecated; use train!(loss, params, data, opt) instead", :train!)
|
||||||
|
train!(loss, (), data, OldOptimiser(opt); cb = cb)
|
||||||
|
end
|
|
@ -1,110 +0,0 @@
|
||||||
call(f, xs...) = f(xs...)
|
|
||||||
|
|
||||||
# note for optimisers: set to zero
|
|
||||||
# p.Δ at the end of the weights update
|
|
||||||
function optimiser(ps, fs...)
|
|
||||||
ps = [Param(p) for p in ps]
|
|
||||||
fs = map(ps) do p
|
|
||||||
os = map(f -> f(p), fs)
|
|
||||||
() -> foreach(call, os)
|
|
||||||
end
|
|
||||||
() -> foreach(call, fs)
|
|
||||||
end
|
|
||||||
|
|
||||||
"""
|
|
||||||
SGD(params, η = 0.1; decay = 0)
|
|
||||||
|
|
||||||
Classic gradient descent optimiser with learning rate `η`.
|
|
||||||
For each parameter `p` and its gradient `δp`, this runs `p -= η*δp`.
|
|
||||||
|
|
||||||
Supports inverse decaying learning rate if the `decay` argument is provided.
|
|
||||||
"""
|
|
||||||
SGD(ps, η = 0.1; decay = 0) =
|
|
||||||
optimiser(ps, p -> invdecay(p, decay), p -> descent(p,η))
|
|
||||||
|
|
||||||
"""
|
|
||||||
Momentum(params, η = 0.01; ρ = 0.9, decay = 0)
|
|
||||||
|
|
||||||
SGD with learning rate `η`, momentum `ρ` and optional learning rate inverse decay.
|
|
||||||
"""
|
|
||||||
Momentum(ps, η = 0.01; ρ = 0.9, decay = 0) =
|
|
||||||
optimiser(ps, p->invdecay(p,decay), p->momentum(p, ρ, η), p->descent(p,1))
|
|
||||||
|
|
||||||
"""
|
|
||||||
Nesterov(params, η = 0.01; ρ = 0.9, decay = 0)
|
|
||||||
|
|
||||||
SGD with learning rate `η`, Nesterov momentum `ρ` and optional learning rate inverse decay.
|
|
||||||
"""
|
|
||||||
Nesterov(ps, η = 0.01; ρ = 0.9, decay = 0) =
|
|
||||||
optimiser(ps, p->invdecay(p,decay), p->nesterov(p, ρ, η), p->descent(p,1))
|
|
||||||
|
|
||||||
"""
|
|
||||||
RMSProp(params, η = 0.001; ρ = 0.9, ϵ = 1e-8, decay = 0)
|
|
||||||
|
|
||||||
[RMSProp](http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf)
|
|
||||||
optimiser. Parameters other than learning rate don't need tuning. Often a good
|
|
||||||
choice for recurrent networks.
|
|
||||||
"""
|
|
||||||
RMSProp(ps, η = 0.001; ρ = 0.9, ϵ = 1e-8, decay = 0) =
|
|
||||||
optimiser(ps, p->rmsprop(p; η=η, ρ=ρ, ϵ=ϵ), p->invdecay(p,decay), p->descent(p,1))
|
|
||||||
|
|
||||||
"""
|
|
||||||
ADAM(params, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0)
|
|
||||||
|
|
||||||
[ADAM](https://arxiv.org/abs/1412.6980v8) optimiser.
|
|
||||||
"""
|
|
||||||
ADAM(ps, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0) =
|
|
||||||
optimiser(ps, p->adam(p; η=η, β1=β1, β2=β2, ϵ=ϵ), p->invdecay(p,decay), p->descent(p,1))
|
|
||||||
|
|
||||||
"""
|
|
||||||
ADAMW((params, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0)
|
|
||||||
|
|
||||||
[ADAMW](https://arxiv.org/abs/1711.05101) fixing weight decay regularization in Adam.
|
|
||||||
"""
|
|
||||||
ADAMW(ps, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0) =
|
|
||||||
optimiser(ps, p->adam(p; η=η, β1=β1, β2=β2, ϵ=ϵ), p->descentweightdecay(p,1,decay))
|
|
||||||
|
|
||||||
"""
|
|
||||||
AdaMax(params, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0)
|
|
||||||
|
|
||||||
[AdaMax](https://arxiv.org/abs/1412.6980v9) optimiser. Variant of ADAM based on
|
|
||||||
the ∞-norm.
|
|
||||||
"""
|
|
||||||
AdaMax(ps, η = 0.002; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0) =
|
|
||||||
optimiser(ps, p->adamax(p; η=η, β1=β1, β2=β2, ϵ=ϵ), p->invdecay(p,decay), p->descent(p,1))
|
|
||||||
|
|
||||||
"""
|
|
||||||
ADAGrad(params, η = 0.01; ϵ = 1e-8, decay = 0)
|
|
||||||
|
|
||||||
[ADAGrad](http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf) optimiser.
|
|
||||||
Parameters don't need tuning.
|
|
||||||
"""
|
|
||||||
ADAGrad(ps, η = 0.01; ϵ = 1e-8, decay = 0) =
|
|
||||||
optimiser(ps, p->adagrad(p; η=η, ϵ=ϵ), p->invdecay(p,decay), p->descent(p,1))
|
|
||||||
|
|
||||||
"""
|
|
||||||
ADADelta(params; ρ = 0.9, ϵ = 1e-8, decay = 0)
|
|
||||||
|
|
||||||
[ADADelta](http://arxiv.org/abs/1212.5701) optimiser. Parameters don't need
|
|
||||||
tuning.
|
|
||||||
"""
|
|
||||||
ADADelta(ps; ρ = 0.9, ϵ = 1e-8, decay = 0) =
|
|
||||||
optimiser(ps, p->adadelta(p; ρ=ρ, ϵ=ϵ), p->descent(p,1))
|
|
||||||
|
|
||||||
"""
|
|
||||||
AMSGrad(params; η = 0.001, β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0)
|
|
||||||
|
|
||||||
[AMSGrad](https://openreview.net/forum?id=ryQu7f-RZ) optimiser. Parameters don't need
|
|
||||||
tuning.
|
|
||||||
"""
|
|
||||||
AMSGrad(ps, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0) =
|
|
||||||
optimiser(ps, p -> amsgrad(p; η = η, β1 = β1, β2 = β2, ϵ = ϵ), p -> invdecay(p, decay), p -> descent(p, 1))
|
|
||||||
|
|
||||||
"""
|
|
||||||
NADAM(params, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0)
|
|
||||||
|
|
||||||
[NADAM](https://openreview.net/pdf?id=OM0jvwB8jIp57ZJjtNEZ) optimiser. Parameters other
|
|
||||||
than learning rate don't need tuning.
|
|
||||||
"""
|
|
||||||
NADAM(ps, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0) =
|
|
||||||
optimiser(ps, p->nadam(p; η=η, β1=β1, β2=β2, ϵ=ϵ), p->invdecay(p,decay), p->descent(p,1))
|
|
|
@ -1,130 +1,304 @@
|
||||||
function descent(p::Param, η::Real)
|
using Flux
|
||||||
function ()
|
using Base: @get!
|
||||||
@. p.x -= η * p.Δ
|
using MacroTools: @forward
|
||||||
@. p.Δ = 0
|
|
||||||
|
const ϵ = 1e-8
|
||||||
|
|
||||||
|
# TODO: should use weak refs
|
||||||
|
|
||||||
|
"""
|
||||||
|
Descent(η)
|
||||||
|
|
||||||
|
Classic gradient descent optimiser with learning rate `η`.
|
||||||
|
For each parameter `p` and its gradient `δp`, this runs `p -= η*δp`.
|
||||||
|
"""
|
||||||
|
mutable struct Descent
|
||||||
|
eta::Float64
|
||||||
|
end
|
||||||
|
|
||||||
|
Descent() = Descent(0.1)
|
||||||
|
|
||||||
|
function update!(o::Descent, x, Δ)
|
||||||
|
Δ .*= o.eta
|
||||||
|
end
|
||||||
|
|
||||||
|
"""
|
||||||
|
Momentum(params, η = 0.01; ρ = 0.9)
|
||||||
|
|
||||||
|
Gradient descent with learning rate `η` and momentum `ρ`.
|
||||||
|
"""
|
||||||
|
mutable struct Momentum
|
||||||
|
eta::Float64
|
||||||
|
rho::Float64
|
||||||
|
velocity::IdDict
|
||||||
|
end
|
||||||
|
|
||||||
|
Momentum(η = 0.01, ρ = 0.9) = Momentum(η, ρ, IdDict())
|
||||||
|
|
||||||
|
function update!(o::Momentum, x, Δ)
|
||||||
|
η, ρ = o.eta, o.rho
|
||||||
|
v = get!(o.velocity, x, zero(x))::typeof(x)
|
||||||
|
@. v = ρ * v - η * Δ
|
||||||
|
@. Δ = -v
|
||||||
|
end
|
||||||
|
|
||||||
|
"""
|
||||||
|
Nesterov(eta, ρ = 0.9)
|
||||||
|
|
||||||
|
Gradient descent with learning rate `η` and Nesterov momentum `ρ`.
|
||||||
|
"""
|
||||||
|
mutable struct Nesterov
|
||||||
|
eta::Float64
|
||||||
|
rho::Float64
|
||||||
|
velocity::IdDict
|
||||||
|
end
|
||||||
|
|
||||||
|
Nesterov(η = 0.001, ρ = 0.9) = Nesterov(η, ρ, IdDict())
|
||||||
|
|
||||||
|
function update!(o::Nesterov, x, Δ)
|
||||||
|
η, ρ = o.eta, o.rho
|
||||||
|
v = get!(o.velocity, x, zero(x))::typeof(x)
|
||||||
|
d = @. ρ^2 * v - (1+ρ) * η * Δ
|
||||||
|
@. v = ρ*v - η*Δ
|
||||||
|
@. Δ = -d
|
||||||
|
end
|
||||||
|
|
||||||
|
"""
|
||||||
|
RMSProp(η = 0.001, ρ = 0.9)
|
||||||
|
|
||||||
|
[RMSProp](http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf)
|
||||||
|
optimiser. Parameters other than learning rate don't need tuning. Often a good
|
||||||
|
choice for recurrent networks.
|
||||||
|
"""
|
||||||
|
mutable struct RMSProp
|
||||||
|
eta::Float64
|
||||||
|
rho::Float64
|
||||||
|
acc::IdDict
|
||||||
|
end
|
||||||
|
|
||||||
|
RMSProp(η = 0.001, ρ = 0.9) = RMSProp(η, ρ, IdDict())
|
||||||
|
|
||||||
|
function update!(o::RMSProp, x, Δ)
|
||||||
|
η, ρ = o.eta, o.rho
|
||||||
|
acc = get!(o.acc, x, zero(x))::typeof(x)
|
||||||
|
@. acc = ρ * acc + (1 - ρ) * Δ^2
|
||||||
|
@. Δ *= η / (√acc + ϵ)
|
||||||
|
end
|
||||||
|
|
||||||
|
"""
|
||||||
|
ADAM(η = 0.001, β = (0.9, 0.999))
|
||||||
|
|
||||||
|
[ADAM](https://arxiv.org/abs/1412.6980v8) optimiser.
|
||||||
|
"""
|
||||||
|
mutable struct ADAM
|
||||||
|
eta::Float64
|
||||||
|
beta::Tuple{Float64,Float64}
|
||||||
|
state::IdDict
|
||||||
|
end
|
||||||
|
|
||||||
|
ADAM(η = 0.001, β = (0.9, 0.999)) = ADAM(η, β, IdDict())
|
||||||
|
|
||||||
|
function update!(o::ADAM, x, Δ)
|
||||||
|
η, β = o.eta, o.beta
|
||||||
|
mt, vt, βp = get!(o.state, x, (zero(x), zero(x), β))
|
||||||
|
@. mt = β[1] * mt + (1 - β[1]) * Δ
|
||||||
|
@. vt = β[2] * vt + (1 - β[2]) * Δ^2
|
||||||
|
@. Δ = mt / (1 - βp[1]) / (√(vt / (1 - βp[2])) + ϵ) * η
|
||||||
|
o.state[x] = (mt, vt, βp .* β)
|
||||||
|
return Δ
|
||||||
|
end
|
||||||
|
|
||||||
|
"""
|
||||||
|
AdaMax(params, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08)
|
||||||
|
|
||||||
|
[AdaMax](https://arxiv.org/abs/1412.6980v9) optimiser. Variant of ADAM based on
|
||||||
|
the ∞-norm.
|
||||||
|
"""
|
||||||
|
mutable struct AdaMax
|
||||||
|
eta::Float64
|
||||||
|
beta::Tuple{Float64,Float64}
|
||||||
|
state::IdDict
|
||||||
|
end
|
||||||
|
|
||||||
|
AdaMax(η = 0.001, β = (0.9, 0.999)) = AdaMax(η, β, IdDict())
|
||||||
|
|
||||||
|
function update!(o::AdaMax, x, Δ)
|
||||||
|
η, β = o.eta, o.beta
|
||||||
|
mt, ut, βp = get!(o.state, x, (zero(x), zero(x), β))
|
||||||
|
@. mt = β[1] * mt + (1 - β[1]) * Δ
|
||||||
|
@. ut = max(β[2] * ut, abs(Δ))
|
||||||
|
@. Δ = (η/(1 - βp[1])) * mt/(ut + ϵ)
|
||||||
|
o.state[x] = (mt, ut, βp .* β)
|
||||||
|
return Δ
|
||||||
|
end
|
||||||
|
|
||||||
|
"""
|
||||||
|
ADAGrad(η = 0.1; ϵ = 1e-8)
|
||||||
|
|
||||||
|
[ADAGrad](http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf) optimiser.
|
||||||
|
Parameters don't need tuning.
|
||||||
|
"""
|
||||||
|
mutable struct ADAGrad
|
||||||
|
eta::Float64
|
||||||
|
acc::IdDict
|
||||||
|
end
|
||||||
|
|
||||||
|
ADAGrad(η = 0.1) = ADAGrad(η, IdDict())
|
||||||
|
|
||||||
|
function update!(o::ADAGrad, x, Δ)
|
||||||
|
η = o.eta
|
||||||
|
acc = get!(o.acc, x, fill(ϵ, size(x)))::typeof(x)
|
||||||
|
@. acc += Δ^2
|
||||||
|
@. Δ *= η / √(acc + ϵ)
|
||||||
|
end
|
||||||
|
|
||||||
|
"""
|
||||||
|
ADADelta(ρ = 0.9, ϵ = 1e-8)
|
||||||
|
|
||||||
|
[ADADelta](http://arxiv.org/abs/1212.5701) optimiser. Parameters don't need
|
||||||
|
tuning.
|
||||||
|
"""
|
||||||
|
mutable struct ADADelta
|
||||||
|
rho::Float64
|
||||||
|
state::IdDict
|
||||||
|
end
|
||||||
|
|
||||||
|
ADADelta(ρ = 0.9) = ADADelta(ρ, IdDict())
|
||||||
|
|
||||||
|
function update!(o::ADADelta, x, Δ)
|
||||||
|
ρ = o.rho
|
||||||
|
acc, Δacc = get!(o.state, x, (zero(x), zero(x)))
|
||||||
|
@. acc = ρ * acc + (1 - ρ) * Δ^2
|
||||||
|
@. Δ *= √(Δacc + ϵ) / √(acc + ϵ)
|
||||||
|
@. Δacc = ρ * Δacc + (1 - ρ) * Δ^2
|
||||||
|
return Δ
|
||||||
|
end
|
||||||
|
|
||||||
|
"""
|
||||||
|
AMSGrad(η = 0.001, β = (0.9, 0.999))
|
||||||
|
|
||||||
|
[AMSGrad](https://openreview.net/forum?id=ryQu7f-RZ) optimiser. Parameters don't need
|
||||||
|
tuning.
|
||||||
|
"""
|
||||||
|
mutable struct AMSGrad
|
||||||
|
eta::Float64
|
||||||
|
beta::Tuple{Float64, Float64}
|
||||||
|
state::IdDict
|
||||||
|
end
|
||||||
|
|
||||||
|
AMSGrad(η = 0.001, β = (0.9, 0.999)) = AMSGrad(η, β, IdDict())
|
||||||
|
|
||||||
|
function update!(o::AMSGrad, x, Δ)
|
||||||
|
η, β = o.eta, o.beta
|
||||||
|
mt, vt, v̂t = get!(o.state, x, (fill(ϵ, size(x)), fill(ϵ, size(x)), fill(ϵ, size(x))))
|
||||||
|
@. mt = β[1] * mt + (1 - β[1]) * Δ
|
||||||
|
@. vt = β[2] * vt + (1 - β[2]) * Δ ^ 2
|
||||||
|
@. v̂t = max.(v̂t, vt)
|
||||||
|
@. Δ = η * mt / √v̂t
|
||||||
|
end
|
||||||
|
|
||||||
|
"""
|
||||||
|
NADAM(η = 0.001, β = (0.9, 0.999))
|
||||||
|
|
||||||
|
[NADAM](http://cs229.stanford.edu/proj2015/054_report.pdf) optimiser. Parameters don't need
|
||||||
|
tuning.
|
||||||
|
"""
|
||||||
|
mutable struct NADAM
|
||||||
|
eta::Float64
|
||||||
|
beta::Tuple{Float64, Float64}
|
||||||
|
state::IdDict
|
||||||
|
end
|
||||||
|
|
||||||
|
NADAM(η = 0.001, β = (0.9, 0.999)) = NADAM(η, β, IdDict())
|
||||||
|
|
||||||
|
function update!(o::NADAM, x, Δ)
|
||||||
|
η, β = o.eta, o.beta
|
||||||
|
β1p, β2p = o.beta
|
||||||
|
mt, vt = get!(o.state, x, (zero(x), zero(x)))
|
||||||
|
@. mt = β[1] * mt + (1 - β[1]) * Δ
|
||||||
|
@. vt = β[2] * vt + (1 - β[2]) * Δ^2
|
||||||
|
@. Δ = (β[1] * mt / (1 - β[1] * β1p) + (1 - β[1]) * Δ / (1 - β1p)) / √(vt * β[2] / (1 - β2p) + ϵ) * η
|
||||||
|
o.state[x] = (mt, vt, (β1p * β[1], β2p * β[2]))
|
||||||
|
return Δ
|
||||||
|
end
|
||||||
|
|
||||||
|
"""
|
||||||
|
ADAMW((η = 0.001, β = (0.9, 0.999), decay = 0)
|
||||||
|
|
||||||
|
[ADAMW](https://arxiv.org/abs/1711.05101) fixing weight decay regularization in Adam.
|
||||||
|
"""
|
||||||
|
ADAMW(η = 0.001, β = (0.9, 0.999), decay = 0) =
|
||||||
|
Optimiser(ADAM(η, β), WeightDecay(wd))
|
||||||
|
|
||||||
|
# Compose optimizers
|
||||||
|
|
||||||
|
"""
|
||||||
|
Optimiser(a, b, c...)
|
||||||
|
|
||||||
|
Combine several optimisers into one; each optimiser produces a modified gradient
|
||||||
|
that will be fed into the next, and this is finally applied to the parameter as
|
||||||
|
usual.
|
||||||
|
"""
|
||||||
|
mutable struct Optimiser
|
||||||
|
os::Vector{Any}
|
||||||
|
end
|
||||||
|
|
||||||
|
Optimiser(o...) = Optimiser(Any[o...])
|
||||||
|
|
||||||
|
@forward Optimiser.os Base.getindex, Base.first, Base.last, Base.lastindex, Base.push!, Base.setindex!
|
||||||
|
@forward Optimiser.os Base.iterate
|
||||||
|
|
||||||
|
Base.getindex(c::Optimiser, i::AbstractArray) = Optimiser(c.os[i]...)
|
||||||
|
|
||||||
|
function update!(o::Optimiser, x, Δ)
|
||||||
|
for opt in o.os
|
||||||
|
Δ = update!(opt, x, Δ)
|
||||||
end
|
end
|
||||||
|
return Δ
|
||||||
end
|
end
|
||||||
|
|
||||||
# Ref: https://arxiv.org/abs/1711.05101.pdf
|
mutable struct InvDecay
|
||||||
function descentweightdecay(p::Param, η::Real, γ::Real)
|
gamma::Float64
|
||||||
function ()
|
state::IdDict
|
||||||
@. p.x = p.x - η * (p.Δ + γ * p.x)
|
end
|
||||||
@. p.Δ = 0
|
|
||||||
|
InvDecay(γ = 0.001) = InvDecay(γ, IdDict())
|
||||||
|
|
||||||
|
function update!(o::InvDecay, x, Δ)
|
||||||
|
γ = o.gamma
|
||||||
|
n = get!(o.state, x, 1)
|
||||||
|
Δ .*= 1 / (1 + γ * n)
|
||||||
|
o.state[x] = n + 1
|
||||||
|
return Δ
|
||||||
|
end
|
||||||
|
|
||||||
|
mutable struct ExpDecay
|
||||||
|
eta::Float64
|
||||||
|
decay::Float64
|
||||||
|
step::Int64
|
||||||
|
clip::Float64
|
||||||
|
current::IdDict
|
||||||
|
end
|
||||||
|
|
||||||
|
ExpDecay(opt = 0.001, decay = 0.1, decay_step = 1000, clip = 1e-4) = ExpDecay(opt, decay, decay_step, clip, IdDict())
|
||||||
|
|
||||||
|
function update!(o::ExpDecay, x, Δ)
|
||||||
|
η, s, decay = o.eta, o.step, o.decay
|
||||||
|
n = o.current[x] = get(o.current, x, 0) + 1
|
||||||
|
if o.current[x]%s == 0 && count(x -> x%s == 0, values(o.current)) == 1
|
||||||
|
η = max(η * decay^(s / n), o.clip)
|
||||||
|
o.eta = η
|
||||||
end
|
end
|
||||||
|
@. Δ *= decay
|
||||||
end
|
end
|
||||||
|
|
||||||
function momentum(p::Param, ρ, η)
|
mutable struct WeightDecay
|
||||||
v = zero(p.x)
|
wd::Real
|
||||||
function ()
|
|
||||||
@. v = ρ * v - η * p.Δ
|
|
||||||
@. p.Δ = -v
|
|
||||||
end
|
|
||||||
end
|
end
|
||||||
|
|
||||||
# Ref. https://arxiv.org/pdf/1212.0901.pdf
|
WeightDecay() = WeightDecay(0)
|
||||||
function nesterov(p::Param, ρ, η)
|
|
||||||
v = zero(p.x)
|
|
||||||
function ()
|
|
||||||
d = @. ρ^2 * v - (1+ρ) * η * p.Δ
|
|
||||||
@. v = ρ*v - η*p.Δ
|
|
||||||
@. p.Δ = -d
|
|
||||||
end
|
|
||||||
end
|
|
||||||
|
|
||||||
function rmsprop(p::Param; η::Real = 0.001, ρ::Real = 0.9, ϵ::Real = 1e-8)
|
function update!(o::WeightDecay, x, Δ)
|
||||||
acc = zero(p.x)
|
wd = o.wd
|
||||||
function ()
|
@. Δ += wd * x
|
||||||
@. acc = ρ * acc + (1 - ρ) * p.Δ^2
|
|
||||||
@. p.Δ *= η / √(acc + ϵ)
|
|
||||||
end
|
|
||||||
end
|
|
||||||
|
|
||||||
function adagrad(p::Param; η::Real = 0.01, ϵ::Real = 1e-8)
|
|
||||||
acc = zero(p.x) .+ ϵ
|
|
||||||
function ()
|
|
||||||
@. acc += p.Δ^2
|
|
||||||
@. p.Δ *= η / √(acc + ϵ)
|
|
||||||
end
|
|
||||||
end
|
|
||||||
|
|
||||||
function adadelta(p::Param; ρ::Real = 0.9, ϵ::Real = 1e-8)
|
|
||||||
acc = zero(p.x)
|
|
||||||
Δacc = zero(p.x)
|
|
||||||
function ()
|
|
||||||
@. acc = ρ * acc + (1 - ρ) * p.Δ^2
|
|
||||||
@. p.Δ *= √(Δacc + ϵ) / √(acc + ϵ)
|
|
||||||
@. Δacc = ρ * Δacc + (1 - ρ) * p.Δ^2
|
|
||||||
end
|
|
||||||
end
|
|
||||||
|
|
||||||
function adam(p::Param; η::Real = 0.001, β1::Real = 0.9, β2::Real = 0.999, ϵ::Real = 1e-8)
|
|
||||||
mt = zero(p.x)
|
|
||||||
vt = zero(p.x)
|
|
||||||
β1p, β2p = β1, β2
|
|
||||||
function ()
|
|
||||||
@. mt = β1 * mt + (1 - β1) * p.Δ
|
|
||||||
@. vt = β2 * vt + (1 - β2) * p.Δ^2
|
|
||||||
@. p.Δ = mt / (1 - β1p) / √(vt / (1 - β2p) + ϵ) * η
|
|
||||||
β1p *= β1
|
|
||||||
β2p *= β2
|
|
||||||
end
|
|
||||||
end
|
|
||||||
|
|
||||||
function adamax(p::Param; η::Real = 0.002, β1::Real = 0.9, β2::Real = 0.999, ϵ::Real = 1e-8)
|
|
||||||
mt = zero(p.x)
|
|
||||||
ut = zero(p.x)
|
|
||||||
β1p = β1
|
|
||||||
function ()
|
|
||||||
@. mt = β1 * mt + (1 - β1) * p.Δ
|
|
||||||
@. ut = max(β2 * ut, abs(p.Δ))
|
|
||||||
@. p.Δ = (η/(1 - β1p)) * mt/(ut + ϵ)
|
|
||||||
β1p *= β1
|
|
||||||
end
|
|
||||||
end
|
|
||||||
|
|
||||||
function amsgrad(p::Param; η::Real = 0.001, β1::Real = 0.9, β2::Real = 0.999, ϵ::Real = 1e-8)
|
|
||||||
mt = zero(p.x)
|
|
||||||
vt = zero(p.x) .+ ϵ
|
|
||||||
v̂t = zero(p.x) .+ ϵ
|
|
||||||
function ()
|
|
||||||
@. mt = β1 * mt + (1 - β1) * p.Δ
|
|
||||||
@. vt = β2 * vt + (1 - β2) * p.Δ ^ 2
|
|
||||||
@. v̂t = max.(v̂t, vt)
|
|
||||||
@. p.Δ = η * mt / √v̂t
|
|
||||||
end
|
|
||||||
end
|
|
||||||
|
|
||||||
function nadam(p::Param; η::Real = 0.001, β1::Real = 0.9, β2::Real = 0.999, ϵ::Real = 1e-8)
|
|
||||||
mt = zero(p.x)
|
|
||||||
vt = zero(p.x)
|
|
||||||
β1p, β2p = β1, β2
|
|
||||||
function ()
|
|
||||||
@. mt = β1 * mt + (1 - β1) * p.Δ
|
|
||||||
@. vt = β2 * vt + (1 - β2) * p.Δ^2
|
|
||||||
@. p.Δ = (β1 * mt / (1 - β1 * β1p) + (1 - β1) * p.Δ / (1 - β1p)) / √(vt * β2 / (1 - β2p) + ϵ) * η
|
|
||||||
β1p *= β1
|
|
||||||
β2p *= β2
|
|
||||||
end
|
|
||||||
end
|
|
||||||
|
|
||||||
clip(p::Param, thresh::Real) = () -> clamp!(p.Δ, -thresh, thresh)
|
|
||||||
|
|
||||||
function expdecay(p::Param, γ::Real)
|
|
||||||
if γ != 0
|
|
||||||
return () -> p.Δ .+= γ .* p.x
|
|
||||||
else
|
|
||||||
return () -> nothing
|
|
||||||
end
|
|
||||||
end
|
|
||||||
|
|
||||||
function invdecay(p::Param, γ::Real)
|
|
||||||
if γ != 0
|
|
||||||
n = 0
|
|
||||||
return () -> begin
|
|
||||||
p.Δ .*= 1 / (1 + γ * n)
|
|
||||||
n += 1
|
|
||||||
end
|
|
||||||
else
|
|
||||||
return () -> nothing
|
|
||||||
end
|
|
||||||
end
|
end
|
||||||
|
|
|
@ -1,7 +1,16 @@
|
||||||
using Juno
|
using Juno
|
||||||
using Flux.Tracker: back!
|
using Flux.Tracker: data, grad, back!
|
||||||
import Base.depwarn
|
import Base.depwarn
|
||||||
|
|
||||||
|
function update!(opt, xs)
|
||||||
|
for x in xs
|
||||||
|
Δ = update!(opt, x.data, x.grad)
|
||||||
|
x.data .-= Δ
|
||||||
|
Δ .= 0
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
# Callback niceties
|
||||||
runall(f) = f
|
runall(f) = f
|
||||||
runall(fs::AbstractVector) = () -> foreach(call, fs)
|
runall(fs::AbstractVector) = () -> foreach(call, fs)
|
||||||
|
|
||||||
|
@ -35,7 +44,7 @@ function stop()
|
||||||
end
|
end
|
||||||
|
|
||||||
"""
|
"""
|
||||||
train!(loss, data, opt)
|
train!(model, loss, data, opt)
|
||||||
|
|
||||||
For each datapoint `d` in `data` computes the gradient of `loss(d...)` through
|
For each datapoint `d` in `data` computes the gradient of `loss(d...)` through
|
||||||
backpropagation and calls the optimizer `opt`.
|
backpropagation and calls the optimizer `opt`.
|
||||||
|
@ -44,7 +53,7 @@ Takes a callback as keyword argument `cb`. For example, this will print "trainin
|
||||||
every 10 seconds:
|
every 10 seconds:
|
||||||
|
|
||||||
```julia
|
```julia
|
||||||
Flux.train!(loss, data, opt,
|
Flux.train!(model, loss, data, opt,
|
||||||
cb = throttle(() -> println("training"), 10))
|
cb = throttle(() -> println("training"), 10))
|
||||||
```
|
```
|
||||||
|
|
||||||
|
@ -52,14 +61,14 @@ The callback can return `:stop` to interrupt the training loop.
|
||||||
|
|
||||||
Multiple optimisers and callbacks can be passed to `opt` and `cb` as arrays.
|
Multiple optimisers and callbacks can be passed to `opt` and `cb` as arrays.
|
||||||
"""
|
"""
|
||||||
function train!(loss, data, opt; cb = () -> ())
|
function train!(loss, ps, data, opt; cb = () -> ())
|
||||||
cb = runall(cb)
|
cb = runall(cb)
|
||||||
opt = runall(opt)
|
opt = runall(opt)
|
||||||
@progress for d in data
|
@progress for d in data
|
||||||
try
|
try
|
||||||
l = loss(d...)
|
l = loss(d...)
|
||||||
@interrupts back!(l)
|
@interrupts back!(l)
|
||||||
opt()
|
update!(opt, ps)
|
||||||
if cb() == :stop
|
if cb() == :stop
|
||||||
depwarn("Use of `:stop` is deprecated; use `Flux.stop()` instead", :stop)
|
depwarn("Use of `:stop` is deprecated; use `Flux.stop()` instead", :stop)
|
||||||
break
|
break
|
||||||
|
|
|
@ -69,15 +69,28 @@ end
|
||||||
# Out-of-place gradients
|
# Out-of-place gradients
|
||||||
|
|
||||||
struct Params
|
struct Params
|
||||||
params::IdSet
|
order::Vector{Any}
|
||||||
Params(xs) = new(IdSet(xs))
|
params::IdSet{Any}
|
||||||
|
Params() = new([], IdSet())
|
||||||
end
|
end
|
||||||
|
|
||||||
@forward Params.params Base.iterate, Base.length
|
@forward Params.order Base.iterate, Base.length
|
||||||
|
|
||||||
|
function Base.push!(ps::Params, x)
|
||||||
|
if !(x in ps.params)
|
||||||
|
push!(ps.order, x)
|
||||||
|
push!(ps.params, x)
|
||||||
|
end
|
||||||
|
return ps
|
||||||
|
end
|
||||||
|
|
||||||
|
Base.push!(ps::Params, x...) = (foreach(x -> push!(ps, x), x); ps)
|
||||||
|
|
||||||
|
Params(xs) = push!(Params(), xs...)
|
||||||
|
|
||||||
function Base.show(io::IO, ps::Params)
|
function Base.show(io::IO, ps::Params)
|
||||||
print(io, "Params([")
|
print(io, "Params([")
|
||||||
join(io, ps.params, ", ")
|
join(io, ps.order, ", ")
|
||||||
print(io, "])")
|
print(io, "])")
|
||||||
end
|
end
|
||||||
|
|
||||||
|
|
|
@ -7,6 +7,7 @@ Base.eltype(::IdSet{T}) where T = T
|
||||||
|
|
||||||
IdSet() = IdSet{Any}()
|
IdSet() = IdSet{Any}()
|
||||||
|
|
||||||
|
Base.push!(s::IdSet) = s
|
||||||
Base.push!(s::IdSet{T}, x::T) where T = (s.dict[x] = nothing; s)
|
Base.push!(s::IdSet{T}, x::T) where T = (s.dict[x] = nothing; s)
|
||||||
Base.delete!(s::IdSet{T}, x::T) where T = (delete!(s.dict, x); s)
|
Base.delete!(s::IdSet{T}, x::T) where T = (delete!(s.dict, x); s)
|
||||||
Base.in(x, s::IdSet) = haskey(s.dict, x)
|
Base.in(x, s::IdSet) = haskey(s.dict, x)
|
||||||
|
|
|
@ -40,7 +40,7 @@ function prefor(f, x; seen = IdSet())
|
||||||
end
|
end
|
||||||
|
|
||||||
function params(m)
|
function params(m)
|
||||||
ps = []
|
ps = Params()
|
||||||
prefor(p ->
|
prefor(p ->
|
||||||
Tracker.istracked(p) && Tracker.isleaf(p) &&
|
Tracker.istracked(p) && Tracker.isleaf(p) &&
|
||||||
!any(p′ -> p′ === p, ps) && push!(ps, p),
|
!any(p′ -> p′ === p, ps) && push!(ps, p),
|
||||||
|
|
|
@ -3,14 +3,37 @@ using Flux.Tracker
|
||||||
using Test
|
using Test
|
||||||
@testset "Optimise" begin
|
@testset "Optimise" begin
|
||||||
w = randn(10, 10)
|
w = randn(10, 10)
|
||||||
@testset for Opt in [SGD, Nesterov, Momentum, ADAM, AdaMax, RMSProp, ps -> ADAGrad(ps, 0.1), ADADelta, AMSGrad, NADAM]
|
@testset for Opt in [ADAGrad, AdaMax, ADADelta, AMSGrad, NADAM, Descent, ADAM, Nesterov, RMSProp, Momentum]
|
||||||
w′ = param(randn(10, 10))
|
w′ = param(randn(10, 10))
|
||||||
loss(x) = Flux.mse(w*x, w′*x)
|
loss(x) = Flux.mse(w*x, w′*x)
|
||||||
opt = Opt([w′])
|
opt = Opt(0.001)
|
||||||
for t=1:10^5
|
if opt isa Descent || opt isa ADAGrad
|
||||||
|
opt = Opt(0.1)
|
||||||
|
end
|
||||||
|
if opt isa ADADelta
|
||||||
|
opt = Opt(0.9)
|
||||||
|
end
|
||||||
|
for t = 1: 10^5
|
||||||
l = loss(rand(10))
|
l = loss(rand(10))
|
||||||
back!(l)
|
back!(l)
|
||||||
opt()
|
delta = Optimise.update!(opt, w′.data, w′.grad)
|
||||||
|
w′.data .-= delta
|
||||||
|
end
|
||||||
|
@test Flux.mse(w, w′) < 0.01
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
@testset "Optimiser" begin
|
||||||
|
w = randn(10, 10)
|
||||||
|
@testset for Opt in [InvDecay, WeightDecay, ExpDecay]
|
||||||
|
w′ = param(randn(10, 10))
|
||||||
|
loss(x) = Flux.mse(w*x, w′*x)
|
||||||
|
opt = Optimiser(Opt(), ADAM(0.001))
|
||||||
|
for t = 1:10^5
|
||||||
|
l = loss(rand(10))
|
||||||
|
back!(l)
|
||||||
|
delta = Optimise.update!(opt, w′.data, w′.grad)
|
||||||
|
w′.data .-= delta
|
||||||
end
|
end
|
||||||
@test Flux.mse(w, w′) < 0.01
|
@test Flux.mse(w, w′) < 0.01
|
||||||
end
|
end
|
||||||
|
@ -21,9 +44,10 @@ end
|
||||||
l = param(1)
|
l = param(1)
|
||||||
|
|
||||||
Flux.train!(() -> (sleep(0.1); i += 1; l),
|
Flux.train!(() -> (sleep(0.1); i += 1; l),
|
||||||
|
(),
|
||||||
Iterators.repeated((), 100),
|
Iterators.repeated((), 100),
|
||||||
()->(),
|
Descent(),
|
||||||
cb = Flux.throttle(() -> (i > 3 && stop()), 1))
|
cb = Flux.throttle(() -> (i > 3 && Flux.stop()), 1))
|
||||||
|
|
||||||
@test 3 < i < 50
|
@test 3 < i < 50
|
||||||
end
|
end
|
||||||
|
|
Loading…
Reference in New Issue