added deprecations and compose
This commit is contained in:
parent
87c7e65a2d
commit
b661db3797
@ -2,10 +2,11 @@ module Optimise
|
|||||||
|
|
||||||
export train!,
|
export train!,
|
||||||
Descent, ADAM, Momentum, Nesterov, RMSProp,
|
Descent, ADAM, Momentum, Nesterov, RMSProp,
|
||||||
ADAGrad, AdaMax, ADADelta, AMSGrad, NADAM,
|
ADAGrad, AdaMax, ADADelta, AMSGrad, NADAM, ADAMW,
|
||||||
InvDecay, ExpDecay, stop, StopException, Compose
|
InvDecay, ExpDecay, stop, StopException, Compose
|
||||||
|
|
||||||
include("optimisers.jl")
|
include("optimisers.jl")
|
||||||
include("train.jl")
|
include("train.jl")
|
||||||
|
include("deprecations.jl")
|
||||||
|
|
||||||
end
|
end
|
128
src/optimise/deprecations.jl
Normal file
128
src/optimise/deprecations.jl
Normal file
@ -0,0 +1,128 @@
|
|||||||
|
using Base: depwarn
|
||||||
|
|
||||||
|
# Attach the legacy `decay` keyword of the deprecated constructors to `opt`.
#
# * `decay == 0`  -> `opt` is returned untouched.
# * `opt` is a `Compose` (the object the `ADAMW(η, β)` constructor returns)
#                 -> a `DescentWeightDecay(1, decay)` stage is appended.
# * anything else -> an `InvDecay(decay)` stage is appended.
#
# Returns the (possibly wrapped) optimiser.
function check_decay(opt, decay)
  decay == 0 && return opt
  # `ADAMW` is a constructor function, not a type, so `opt isa ADAMW` would
  # raise a TypeError. Its result is a `Compose`, which is what we test for.
  stage = opt isa Compose ? DescentWeightDecay(1, decay) : InvDecay(decay)
  return Compose(opt, stage)
end
|
||||||
|
|
||||||
|
# legacy update rule
|
||||||
|
# Build the old-style, zero-argument update closure: every call asks `opt` for
# a step for each parameter in `ps` (via `update!` on its `data` and `grad`)
# and subtracts that step from the parameter's `data` in place.
function updaterule(opt, ps)
  function step!()
    foreach(ps) do p
      p.data .-= update!(opt, p.data, p.grad)
    end
  end
  return step!
end
|
||||||
|
|
||||||
|
# Deprecated parameter-first form: warns, builds `Descent(η)`, applies the
# legacy `decay` keyword, and returns the closure made by `updaterule`.
function Descent(params::AbstractArray, η = 0.1; decay = 0.)
  depwarn("Descent(ps::Param) is deprecated; use Descent(η::Float64) instead", :Descent)
  rule = check_decay(Descent(η), decay)
  return updaterule(rule, params)
end
|
||||||
|
|
||||||
|
# Deprecated parameter-first form: warns, builds `Momentum(η, ρ)`, applies the
# legacy `decay` keyword, and returns the closure made by `updaterule`.
function Momentum(params::AbstractArray, η = 0.01; ρ = 0.9, decay = 0.)
  depwarn("Momentum(ps::Param) is deprecated; use Momentum(η::Float64) instead", :Momentum)
  rule = check_decay(Momentum(η, ρ), decay)
  return updaterule(rule, params)
end
|
||||||
|
|
||||||
|
# Deprecated parameter-first form: warns, builds `Nesterov(η, ρ)`, applies the
# legacy `decay` keyword, and returns the closure made by `updaterule`.
function Nesterov(params::AbstractArray, η = 0.001; ρ = 0.9, decay = 0.)
  depwarn("Nesterov(ps::Param) is deprecated; use Nesterov(η::Float64) instead", :Nesterov)
  rule = check_decay(Nesterov(η, ρ), decay)
  return updaterule(rule, params)
end
|
||||||
|
|
||||||
|
# Deprecated parameter-first form: warns, builds `RMSProp(η, ρ)`, applies the
# legacy `decay` keyword, and returns the closure made by `updaterule`.
function RMSProp(params::AbstractArray, η = 0.001; ρ = 0.9, decay = 0.)
  depwarn("RMSProp(ps::Param) is deprecated; use RMSProp(η::Float64) instead", :RMSProp)
  rule = check_decay(RMSProp(η, ρ), decay)
  return updaterule(rule, params)
end
|
||||||
|
|
||||||
|
# Deprecated parameter-first form: warns, packs `β1`/`β2` into the tuple the
# new `ADAM(η, β)` constructor takes, applies the legacy `decay` keyword, and
# returns the closure made by `updaterule`.
function ADAM(params::AbstractArray, η = 0.001; β1 = 0.9, β2 = 0.999, decay = 0.)
  depwarn("ADAM(ps::Param) is deprecated; use ADAM(η::Float64) instead", :ADAM)
  rule = check_decay(ADAM(η, (β1, β2)), decay)
  return updaterule(rule, params)
end
|
||||||
|
|
||||||
|
# Deprecated parameter-first form: warns, builds `ADAGrad(η)`, applies the
# legacy `decay` keyword, and returns the closure made by `updaterule`.
#
# `η` accepts any `Real` (previously only `Float64`), in line with the other
# deprecated constructors in this file, which do not restrict the step size.
function ADAGrad(params::AbstractArray, η::Real = 0.1; decay = 0.)
  depwarn("ADAGrad(ps::Param) is deprecated; use ADAGrad(η::Float64) instead", :ADAGrad)
  rule = check_decay(ADAGrad(η), decay)
  return updaterule(rule, params)
end
|
||||||
|
|
||||||
|
# Deprecated parameter-first form: warns, builds `ADADelta(ρ)`, applies the
# legacy `decay` keyword, and returns the closure made by `updaterule`.
#
# The positional argument is `ρ`, so the deprecation message now says so.
# `ρ` accepts any `Real` (previously only `Float64`), in line with the other
# deprecated constructors in this file.
function ADADelta(params::AbstractArray, ρ::Real = 0.9; decay = 0.)
  depwarn("ADADelta(ps::Param) is deprecated; use ADADelta(ρ::Float64) instead", :ADADelta)
  rule = check_decay(ADADelta(ρ), decay)
  return updaterule(rule, params)
end
|
||||||
|
|
||||||
|
# Deprecated parameter-first form: warns, packs `β1`/`β2` into the tuple the
# new `AdaMax(η, β)` constructor takes, applies the legacy `decay` keyword, and
# returns the closure made by `updaterule`.
function AdaMax(params::AbstractArray, η = 0.001; β1 = 0.9, β2 = 0.999, decay = 0.)
  depwarn("AdaMax(ps::Param) is deprecated; use AdaMax(η::Float64) instead", :AdaMax)
  rule = check_decay(AdaMax(η, (β1, β2)), decay)
  return updaterule(rule, params)
end
|
||||||
|
|
||||||
|
# Deprecated parameter-first form: warns, packs `β1`/`β2` into the tuple the
# new `AMSGrad(η, β)` constructor takes, applies the legacy `decay` keyword,
# and returns the closure made by `updaterule`.
function AMSGrad(params::AbstractArray, η = 0.001; β1 = 0.9, β2 = 0.999, decay = 0.)
  depwarn("AMSGrad(ps::Param) is deprecated; use AMSGrad(η::Float64) instead", :AMSGrad)
  rule = check_decay(AMSGrad(η, (β1, β2)), decay)
  return updaterule(rule, params)
end
|
||||||
|
|
||||||
|
# Deprecated parameter-first form: warns, packs `β1`/`β2` into the tuple the
# new `NADAM(η, β)` constructor takes, applies the legacy `decay` keyword, and
# returns the closure made by `updaterule`.
function NADAM(params::AbstractArray, η = 0.001; β1 = 0.9, β2 = 0.999, decay = 0.)
  depwarn("NADAM(ps::Param) is deprecated; use NADAM(η::Float64) instead", :NADAM)
  rule = check_decay(NADAM(η, (β1, β2)), decay)
  return updaterule(rule, params)
end
|
||||||
|
|
||||||
|
# Deprecated parameter-first form: warns, packs `β1`/`β2` into the tuple passed
# to `ADAMW(η, β)`, applies the legacy `decay` keyword, and returns the closure
# made by `updaterule`.
function ADAMW(params::AbstractArray, η = 0.001; β1 = 0.9, β2 = 0.999, decay = 0.)
  depwarn("ADAMW(ps::Param) is deprecated; use ADAMW(η::Float64) instead", :ADAMW)
  rule = check_decay(ADAMW(η, (β1, β2)), decay)
  return updaterule(rule, params)
end
|
@ -1,5 +1,6 @@
|
|||||||
using Flux
|
using Flux
|
||||||
using Base: @get!
|
using Base: @get!
|
||||||
|
using MacroTools: @forward
|
||||||
|
|
||||||
const ϵ = 1e-8
|
const ϵ = 1e-8
|
||||||
|
|
||||||
@ -15,6 +16,7 @@ mutable struct Descent
|
|||||||
eta::Float64
|
eta::Float64
|
||||||
end
|
end
|
||||||
|
|
||||||
|
# Zero-argument convenience constructor with the default step size.
# Only the no-argument method is defined here: a one-argument
# `Descent(η) = Descent(η)` would replace the struct's default converting
# constructor with a method that calls itself, so `Descent(1)` (any
# non-`Float64` argument) would recurse until a StackOverflowError.
Descent() = Descent(0.1)
|
||||||
# Scale the gradient `Δ` in place by the optimiser's step size and return it.
# `x` is accepted for interface uniformity and is not read.
function update!(o::Descent, x, Δ)
  lr = o.eta
  Δ .= Δ .* lr
  return Δ
end
|
||||||
@ -30,7 +32,7 @@ mutable struct Momentum
|
|||||||
velocity::IdDict
|
velocity::IdDict
|
||||||
end
|
end
|
||||||
|
|
||||||
Momentum(η, ρ = 0.9) = Momentum(η, ρ, IdDict())
|
# Convenience constructor: default step size and momentum, with a fresh
# `IdDict` for the velocity state.
function Momentum(η = 0.01, ρ = 0.9)
  return Momentum(η, ρ, IdDict())
end
|
||||||
|
|
||||||
function update!(o::Momentum, x, Δ)
|
function update!(o::Momentum, x, Δ)
|
||||||
η, ρ = o.eta, o.rho
|
η, ρ = o.eta, o.rho
|
||||||
@ -50,7 +52,7 @@ mutable struct Nesterov
|
|||||||
velocity::IdDict
|
velocity::IdDict
|
||||||
end
|
end
|
||||||
|
|
||||||
Nesterov(η, ρ = 0.9) = Nesterov(η, ρ, IdDict())
|
# Convenience constructor: default step size and momentum, with a fresh
# `IdDict` for the velocity state.
function Nesterov(η = 0.001, ρ = 0.9)
  return Nesterov(η, ρ, IdDict())
end
|
||||||
|
|
||||||
function update!(o::Nesterov, x, Δ)
|
function update!(o::Nesterov, x, Δ)
|
||||||
η, ρ = o.eta, o.rho
|
η, ρ = o.eta, o.rho
|
||||||
@ -219,10 +221,46 @@ function update!(o::NADAM, x, Δ)
|
|||||||
return Δ
|
return Δ
|
||||||
end
|
end
|
||||||
|
|
||||||
|
"""
|
||||||
|
    ADAMW(η = 0.001, β = (0.9, 0.999), η_decay = 1, γ_decay = 0)
|
||||||
|
|
||||||
|
[ADAMW](https://arxiv.org/abs/1711.05101) fixing weight decay regularization in Adam.
|
||||||
|
"""
|
||||||
|
function ADAMW(η = 0.001, β = (0.9, 0.999), η_decay = 1, γ_decay = 0)
  # An ADAM stage followed by a weight-decay stage, chained with `Compose`.
  adam = ADAM(η, β, IdDict())
  weightdecay = DescentWeightDecay(η_decay, γ_decay)
  return Compose(adam, weightdecay)
end
|
||||||
|
|
||||||
|
# Compose optimizers
|
||||||
|
|
||||||
|
"""
|
||||||
|
`Compose(Compose(...), ...)`
|
||||||
|
|
||||||
|
Compose optimizers to support inbuilt or custom gradient updates while fitting the loss.
|
||||||
|
|
||||||
|
Example:\n\n
|
||||||
|
`Compose(ADAM(), Compose(RMSProp(0.001), ExpDecay(0.02)))`
|
||||||
|
"""
|
||||||
mutable struct Compose
  # The chained optimisers, applied in order by `update!`.
  os::Vector{Any}

  # Accept only vectors here (converted to `Vector{Any}` by `new`). Without
  # this inner constructor Julia also generates a catch-all `Compose(os::Any)`
  # that is more specific than a variadic `Compose(o...)`; a call with one
  # non-vector argument then attempts `convert(Vector{Any}, arg)` and fails
  # instead of reaching the variadic method.
  Compose(os::AbstractVector) = new(os)
end
|
||||||
|
|
||||||
|
# Variadic constructor: arguments are passed through `flattenCompose`, so
# nested `Compose` values are expanded into a single flat list.
function Compose(o...)
  return Compose(flattenCompose(o...))
end

# Delegate the collection interface to the wrapped vector `os`.
@forward Compose.os Base.getindex, Base.setindex!, Base.push!
@forward Compose.os Base.first, Base.last, Base.lastindex, Base.iterate

# Indexing with an array of indices builds a new `Compose` from the selection.
function Base.getindex(c::Compose, i::AbstractArray)
  return Compose(c.os[i]...)
end
|
||||||
|
|
||||||
|
# Collect the given optimisers into one flat `Vector{Any}`, recursively
# expanding every `Compose` argument into its members (in order).
# Returns an empty vector when called with no arguments.
function flattenCompose(o...)
  res = Any[]
  for opt in o
    if opt isa Compose
      # `append!` copes with an empty nested `Compose`; splatting zero items
      # into `push!` does not, and splatting is needless here anyway.
      append!(res, flattenCompose(opt.os...))
    else
      push!(res, opt)
    end
  end
  return res
end
|
||||||
|
|
||||||
function update!(o::Compose, x, Δ)
|
function update!(o::Compose, x, Δ)
|
||||||
for opt in o.os
|
for opt in o.os
|
||||||
Δ = update!(opt, x, Δ)
|
Δ = update!(opt, x, Δ)
|
||||||
@ -256,3 +294,15 @@ function update!(o::ExpDecay, x, Δ)
|
|||||||
γ = o.gamma
|
γ = o.gamma
|
||||||
@. Δ += γ * x
|
@. Δ += γ * x
|
||||||
end
|
end
|
||||||
|
|
||||||
|
# Optimiser stage holding a step size `eta` and a weight-decay coefficient
# `gamma`; see the matching `update!` method for how they are applied.
mutable struct DescentWeightDecay
  eta::Real
  gamma::Real
end

# Defaults: `DescentWeightDecay()` is `(1, 0)` and `DescentWeightDecay(η)` is
# `(η, 0)`. The two-argument form is left to the struct's own constructors;
# redefining it as `DescentWeightDecay(η, γ) = DescentWeightDecay(η, γ)`
# would overwrite the default converting constructor with a method that calls
# itself, turning a bad argument into a StackOverflowError.
DescentWeightDecay(η = 1) = DescentWeightDecay(η, 0)
|
||||||
|
# Apply a decayed descent step directly to `x`:
#     x <- x - eta * (Δ + gamma * x)
# `x` is overwritten in place; `Δ` is returned unmodified.
# NOTE(review): a caller that also subtracts the returned `Δ` from `x` would
# apply the gradient step twice — confirm this is intended.
function update!(o::DescentWeightDecay, x, Δ)
  lr, wd = o.eta, o.gamma
  x .= x .- lr .* (Δ .+ wd .* x)
  return Δ
end
|
||||||
|
Loading…
Reference in New Issue
Block a user