added deprecations and compose

This commit is contained in:
Dhairya Gandhi 2018-10-01 05:30:53 +05:30
parent 87c7e65a2d
commit b661db3797
3 changed files with 182 additions and 3 deletions

View File

@ -2,10 +2,11 @@ module Optimise
export train!, export train!,
Descent, ADAM, Momentum, Nesterov, RMSProp, Descent, ADAM, Momentum, Nesterov, RMSProp,
ADAGrad, AdaMax, ADADelta, AMSGrad, NADAM, ADAGrad, AdaMax, ADADelta, AMSGrad, NADAM, ADAMW,
InvDecay, ExpDecay, stop, StopException, Compose InvDecay, ExpDecay, stop, StopException, Compose
include("optimisers.jl") include("optimisers.jl")
include("train.jl") include("train.jl")
include("deprecations.jl")
end end

View File

@ -0,0 +1,128 @@
using Base: depwarn
function check_decay(opt, decay)
if decay == 0.
opt = opt
else
if opt isa ADAMW
opt = Compose(opt, DescentWeightDecay(1, decay))
else
opt = Compose(opt, InvDecay(decay))
end
end
opt
end
# legacy update rule
function updaterule(opt, ps)
() -> begin
for p in ps
delta = update!(opt, p.data, p.grad)
p.data .-= delta
end
end
end
function Descent(params::AbstractArray, η = 0.1; decay = 0.)
depwarn("Descent(ps::Param) is deprecated; use Descent(η::Float64) instead", :Descent)
ps = params
opt = Descent(η)
opt = check_decay(opt, decay)
updaterule(opt, ps)
end
function Momentum(params::AbstractArray, η = 0.01; ρ = 0.9, decay = 0.)
depwarn("Momentum(ps::Param) is deprecated; use Momentum(η::Float64) instead", :Momentum)
ps = params
opt = Momentum(η, ρ)
opt = check_decay(opt, decay)
updaterule(opt, ps)
end
function Nesterov(params::AbstractArray, η = 0.001; ρ = 0.9, decay = 0.)
depwarn("Nesterov(ps::Param) is deprecated; use Nesterov(η::Float64) instead", :Nesterov)
ps = params
opt = Nesterov(η, ρ)
opt = check_decay(opt, decay)
updaterule(opt, ps)
end
function RMSProp(params::AbstractArray, η = 0.001; ρ = 0.9, decay = 0.)
depwarn("RMSProp(ps::Param) is deprecated; use RMSProp(η::Float64) instead", :RMSProp)
ps = params
opt = RMSProp(η, ρ)
opt = check_decay(opt, decay)
updaterule(opt, ps)
end
function ADAM(params::AbstractArray, η = 0.001; β1 = 0.9, β2 = 0.999, decay = 0.)
depwarn("ADAM(ps::Param) is deprecated; use ADAM(η::Float64) instead", :ADAM)
ps = params
β = (β1, β2)
opt = ADAM(η, β)
opt = check_decay(opt, decay)
updaterule(opt, ps)
end
function ADAGrad(params::AbstractArray, η::Float64 = 0.1; decay = 0.)
depwarn("ADAGrad(ps::Param) is deprecated; use ADAGrad(η::Float64) instead", :ADAGrad)
ps = params
opt = ADAGrad(η)
opt = check_decay(opt, decay)
updaterule(opt, ps)
end
function ADADelta(params::AbstractArray, ρ::Float64 = 0.9; decay = 0.)
depwarn("ADADelta(ps::Param) is deprecated; use ADADelta(η::Float64) instead", :ADADelta)
ps = params
opt = ADADelta(ρ)
opt = check_decay(opt, decay)
updaterule(opt, ps)
end
function AdaMax(params::AbstractArray, η = 0.001; β1 = 0.9, β2 = 0.999, decay = 0.)
depwarn("AdaMax(ps::Param) is deprecated; use AdaMax(η::Float64) instead", :AdaMax)
ps = params
β = (β1, β2)
opt = AdaMax(η, β)
opt = check_decay(opt, decay)
updaterule(opt, ps)
end
function AMSGrad(params::AbstractArray, η = 0.001; β1 = 0.9, β2 = 0.999, decay = 0.)
depwarn("AMSGrad(ps::Param) is deprecated; use AMSGrad(η::Float64) instead", :AMSGrad)
ps = params
β = (β1, β2)
opt = AMSGrad(η, β)
opt = check_decay(opt, decay)
updaterule(opt, ps)
end
function NADAM(params::AbstractArray, η = 0.001; β1 = 0.9, β2 = 0.999, decay = 0.)
depwarn("NADAM(ps::Param) is deprecated; use NADAM(η::Float64) instead", :NADAM)
ps = params
β = (β1, β2)
opt = NADAM(η, β)
opt = check_decay(opt, decay)
updaterule(opt, ps)
end
function ADAMW(params::AbstractArray, η = 0.001; β1 = 0.9, β2 = 0.999, decay = 0.)
depwarn("ADAMW(ps::Param) is deprecated; use ADAMW(η::Float64) instead", :ADAMW)
ps = params
β = (β1, β2)
opt = ADAMW(η, β)
opt = check_decay(opt, decay)
updaterule(opt, ps)
end

View File

@ -1,5 +1,6 @@
using Flux using Flux
using Base: @get! using Base: @get!
using MacroTools: @forward
const ϵ = 1e-8 const ϵ = 1e-8
@ -15,6 +16,7 @@ mutable struct Descent
eta::Float64 eta::Float64
end end
Descent(η = 0.1) = Descent(η)
function update!(o::Descent, x, Δ) function update!(o::Descent, x, Δ)
Δ .*= o.eta Δ .*= o.eta
end end
@ -30,7 +32,7 @@ mutable struct Momentum
velocity::IdDict velocity::IdDict
end end
Momentum(η, ρ = 0.9) = Momentum(η, ρ, IdDict()) Momentum(η = 0.01, ρ = 0.9) = Momentum(η, ρ, IdDict())
function update!(o::Momentum, x, Δ) function update!(o::Momentum, x, Δ)
η, ρ = o.eta, o.rho η, ρ = o.eta, o.rho
@ -50,7 +52,7 @@ mutable struct Nesterov
velocity::IdDict velocity::IdDict
end end
Nesterov(η, ρ = 0.9) = Nesterov(η, ρ, IdDict()) Nesterov(η = 0.001, ρ = 0.9) = Nesterov(η, ρ, IdDict())
function update!(o::Nesterov, x, Δ) function update!(o::Nesterov, x, Δ)
η, ρ = o.eta, o.rho η, ρ = o.eta, o.rho
@ -219,10 +221,46 @@ function update!(o::NADAM, x, Δ)
return Δ return Δ
end end
"""
ADAMW((η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0)
[ADAMW](https://arxiv.org/abs/1711.05101) fixing weight decay regularization in Adam.
"""
ADAMW(η = 0.001, β = (0.9, 0.999), η_decay = 1, γ_decay = 0) = Compose(ADAM(η, β, IdDict()), DescentWeightDecay(η_decay, γ_decay))
# Compose optimizers
"""
`Compose(Compose(...), ...)`
Compose optimizers to support inbuilt or custom gradient updates while fitting the loss.
Example:\n\n
`Compose(ADAM(), Compose(RMSProp(0.001), ExpDecay(0.02)))`
"""
mutable struct Compose mutable struct Compose
os::Vector{Any} os::Vector{Any}
end end
Compose(o...) = Compose(flattenCompose(o...))
@forward Compose.os Base.getindex, Base.first, Base.last, Base.lastindex, Base.push!, Base.setindex!
@forward Compose.os Base.iterate
Base.getindex(c::Compose, i::AbstractArray) = Compose(c.os[i]...)
function flattenCompose(o...)
res = []
for opt in o
if opt isa Compose
push!(res, flattenCompose(opt.os...)...)
else
push!(res, opt)
end
end
return res
end
function update!(o::Compose, x, Δ) function update!(o::Compose, x, Δ)
for opt in o.os for opt in o.os
Δ = update!(opt, x, Δ) Δ = update!(opt, x, Δ)
@ -256,3 +294,15 @@ function update!(o::ExpDecay, x, Δ)
γ = o.gamma γ = o.gamma
@. Δ += γ * x @. Δ += γ * x
end end
mutable struct DescentWeightDecay
eta::Real
gamma::Real
end
DescentWeightDecay(η = 1, γ = 0) = DescentWeightDecay(η, γ)
function update!(o::DescentWeightDecay, x, Δ)
η, γ = o.eta, o.gamma
@. x = x - η * (Δ + γ * x)
Δ
end