Mike J Innes 2018-10-31 14:58:55 +00:00
parent bebf4eb95f
commit bffaceee02
5 changed files with 36 additions and 51 deletions

View File

@@ -21,7 +21,7 @@ using .Optimise
using .Optimise: @epochs
export Descent, ADAM, Momentum, Nesterov, RMSProp,
ADAGrad, AdaMax, ADADelta, AMSGrad, NADAM,
ADAMW, InvDecay, ExpDecay, WeightDecay, DescentWeightDecay
ADAMW, InvDecay, ExpDecay, WeightDecay
include("utils.jl")
include("onehot.jl")

View File

@@ -1,17 +1,6 @@
using Base: depwarn
function check_decay(opt, decay)
if decay == 0.
opt = opt
else
if opt isa ADAMW
opt = Optimiser(opt, WeightDecay(decay))
else
opt = Optimiser(opt, InvDecay(decay))
end
end
opt
end
check_decay(opt, decay) = decay == 0 ? opt : Optimiser(opt, InvDecay(decay))
# legacy update rule
function updaterule(opt, ps)
@@ -24,7 +13,7 @@ function updaterule(opt, ps)
end
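A hedged sketch of what the simplified check_decay now produces for a nonzero decay keyword (ADAMW is no longer special-cased here; its weight decay is handled in its own deprecation below):

using Flux
# Momentum(ps, 0.01, decay = 0.001) now effectively builds
opt = Flux.Optimise.Optimiser(Momentum(0.01, 0.9), InvDecay(0.001))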
function Descent(params::AbstractArray, η = 0.1; decay = 0.)
depwarn("Descent(ps::Param) is deprecated; use Descent(η::Float64) instead", :Descent)
depwarn("Descent(params) is deprecated; use Descent(η::Float64) instead", :Descent)
ps = params
opt = Descent(η)
@@ -33,7 +22,7 @@ function Descent(params::AbstractArray, η = 0.1; decay = 0.)
end
function Momentum(params::AbstractArray, η = 0.01; ρ = 0.9, decay = 0.)
depwarn("Momentum(ps::Param) is deprecated; use Momentum(η::Float64) instead", :Momentum)
depwarn("Momentum(params) is deprecated; use Momentum(η::Float64) instead", :Momentum)
ps = params
opt = Momentum(η, ρ)
@@ -42,7 +31,7 @@ function Momentum(params::AbstractArray, η = 0.01; ρ = 0.9, decay = 0.)
end
function Nesterov(params::AbstractArray, η = 0.001; ρ = 0.9, decay = 0.)
depwarn("Nesterov(ps::Param) is deprecated; use Nesterov(η::Float64) instead", :Nesterov)
depwarn("Nesterov(params) is deprecated; use Nesterov(η::Float64) instead", :Nesterov)
ps = params
opt = Nesterov(η, ρ)
@@ -51,7 +40,7 @@ function Nesterov(params::AbstractArray, η = 0.001; ρ = 0.9, decay = 0.)
end
function RMSProp(params::AbstractArray, η = 0.001; ρ = 0.9, decay = 0.)
depwarn("RMSProp(ps::Param) is deprecated; use RMSProp(η::Float64) instead", :RMSProp)
depwarn("RMSProp(params) is deprecated; use RMSProp(η::Float64) instead", :RMSProp)
ps = params
opt = RMSProp(η, ρ)
@@ -60,7 +49,7 @@ function RMSProp(params::AbstractArray, η = 0.001; ρ = 0.9, decay = 0.)
end
function ADAM(params::AbstractArray, η = 0.001; β1 = 0.9, β2 = 0.999, decay = 0.)
depwarn("ADAM(ps::Param) is deprecated; use ADAM(η::Float64) instead", :ADAM)
depwarn("ADAM(params) is deprecated; use ADAM(η::Float64) instead", :ADAM)
ps = params
β = (β1, β2)
@@ -70,7 +59,7 @@ function ADAM(params::AbstractArray, η = 0.001; β1 = 0.9, β2 = 0.999, decay =
end
function ADAGrad(params::AbstractArray, η::Float64 = 0.1; decay = 0.)
depwarn("ADAGrad(ps::Param) is deprecated; use ADAGrad(η::Float64) instead", :ADAGrad)
depwarn("ADAGrad(params) is deprecated; use ADAGrad(η::Float64) instead", :ADAGrad)
ps = params
opt = ADAGrad(η)
@@ -79,7 +68,7 @@ function ADAGrad(params::AbstractArray, η::Float64 = 0.1; decay = 0.)
end
function ADADelta(params::AbstractArray, ρ::Float64 = 0.9; decay = 0.)
depwarn("ADADelta(ps::Param) is deprecated; use ADADelta(η::Float64) instead", :ADADelta)
depwarn("ADADelta(params) is deprecated; use ADADelta(η::Float64) instead", :ADADelta)
ps = params
opt = ADADelta(ρ)
@@ -88,7 +77,7 @@ function ADADelta(params::AbstractArray, ρ::Float64 = 0.9; decay = 0.)
end
function AdaMax(params::AbstractArray, η = 0.001; β1 = 0.9, β2 = 0.999, decay = 0.)
depwarn("AdaMax(ps::Param) is deprecated; use AdaMax(η::Float64) instead", :AdaMax)
depwarn("AdaMax(params) is deprecated; use AdaMax(η::Float64) instead", :AdaMax)
ps = params
β = (β1, β2)
@@ -98,7 +87,7 @@ function AdaMax(params::AbstractArray, η = 0.001; β1 = 0.9, β2 = 0.999, decay
end
function AMSGrad(params::AbstractArray, η = 0.001; β1 = 0.9, β2 = 0.999, decay = 0.)
depwarn("AMSGrad(ps::Param) is deprecated; use AMSGrad(η::Float64) instead", :AMSGrad)
depwarn("AMSGrad(params) is deprecated; use AMSGrad(η::Float64) instead", :AMSGrad)
ps = params
β = (β1, β2)
@@ -108,7 +97,7 @@ function AMSGrad(params::AbstractArray, η = 0.001; β1 = 0.9, β2 = 0.999, deca
end
function NADAM(params::AbstractArray, η = 0.001; β1 = 0.9, β2 = 0.999, decay = 0.)
depwarn("NADAM(ps::Param) is deprecated; use NADAM(η::Float64) instead", :NADAM)
depwarn("NADAM(params) is deprecated; use NADAM(η::Float64) instead", :NADAM)
ps = params
β = (β1, β2)
@@ -118,21 +107,26 @@ function NADAM(params::AbstractArray, η = 0.001; β1 = 0.9, β2 = 0.999, decay
end
function ADAMW(params::AbstractArray, η = 0.001; β1 = 0.9, β2 = 0.999, decay = 0.)
depwarn("ADAMW(ps::Param) is deprecated; use ADAMW(η::Float64) instead", :ADAMW)
depwarn("ADAMW(params) is deprecated; use ADAMW(η::Float64) instead", :ADAMW)
ps = params
β = (β1, β2)
opt = ADAMW(η, β)
opt = check_decay(opt, decay)
decay != 0 && (opt = Optimiser(opt, WeightDecay(decay)))
updaterule(opt, ps)
end
# Train function
function train!(loss::Function, data, opt; cb = () -> ())
depwarn("train!(loss, data, opt; cb) is deprecated; use train!(loss, params, data, opt; cb) instead", :train)
if fieldnames(typeof(opt)) !== ()
train!(loss, opt.ps, data, opt.opt; cb = cb)
else
train!(loss, (), data, opt; cb = cb)
end
# Old training loop
struct OldOptimiser
func
end
update!(opt::OldOptimiser, ps) = opt.func()
# Train function
function train!(loss, data, opt; cb = () -> ())
depwarn("train!(loss, data, opt) is deprecated; use train!(loss, params, data, opt) instead", :train!)
train!(loss, (), data, OldOptimiser(opt); cb = cb)
end
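A minimal sketch of the deprecated call path (the names below are illustrative): an old-style function "optimiser" is wrapped in OldOptimiser, and update!(::OldOptimiser, ps) in the training loop simply runs the stored closure.

using Flux
l = param(1)                   # the loss must return a tracked value for back!
oldstyle_opt = () -> ()        # old API: the optimiser was just a closure
Flux.train!(() -> l, Iterators.repeated((), 3), oldstyle_opt)  # warns, then runs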

View File

@@ -17,6 +17,7 @@ mutable struct Descent
end
Descent() = Descent(0.1)
function update!(o::Descent, x, Δ)
Δ .*= o.eta
end
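A quick sketch of the in-place contract above, assuming update! is reachable as Flux.Optimise.update!: the optimiser only rescales the gradient, and the surrounding training loop applies the result to the parameter.

using Flux
Δ = ones(3)
Flux.Optimise.update!(Descent(0.1), zeros(3), Δ)
# Δ ≈ [0.1, 0.1, 0.1]; the caller is expected to do x .-= Δ afterwards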
@@ -152,7 +153,7 @@ function update!(o::ADAGrad, x, Δ)
end
"""
ADADelta(params; ρ = 0.9, ϵ = 1e-8)
ADADelta(ρ = 0.9, ϵ = 1e-8)
[ADADelta](http://arxiv.org/abs/1212.5701) optimiser. Parameters don't need
tuning.
@@ -222,16 +223,18 @@ function update!(o::NADAM, x, Δ)
end
"""
ADAMW((η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0)
ADAMW(η = 0.001, β = (0.9, 0.999), decay = 0)
[ADAMW](https://arxiv.org/abs/1711.05101) fixing weight decay regularization in Adam.
"""
ADAMW(η = 0.001, β = (0.9, 0.999), η_decay = 1, wd = 0) = Optimiser(ADAM(η, β, IdDict()), DescentWeightDecay(η_decay, wd))
ADAMW(η = 0.001, β = (0.9, 0.999), decay = 0) =
Optimiser(ADAM(η, β), WeightDecay(decay))
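Under this definition ADAMW is simply an ADAM step composed with a decoupled WeightDecay stage, so the two spellings below should construct equivalent optimisers (a sketch, assuming Optimiser is reachable via Flux.Optimise):

using Flux
opt  = ADAMW(0.001, (0.9, 0.999), 0.1)
opt2 = Flux.Optimise.Optimiser(ADAM(0.001, (0.9, 0.999)), WeightDecay(0.1))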
# Compose optimizers
"""
Optimiser(a, b, c...)
Combine several optimisers into one; each optimiser produces a modified gradient
that will be fed into the next, and this is finally applied to the parameter as
usual.
@@ -254,8 +257,6 @@ function update!(o::Optimiser, x, Δ)
return Δ
end
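A brief sketch of the chaining described in the Optimiser docstring above: each stage's update! transforms the gradient in turn, and only the final result is applied to the parameter. Assuming update! is reachable as Flux.Optimise.update!:

using Flux
opt = Flux.Optimise.Optimiser(Descent(0.5), Descent(0.5))
Δ = ones(2)
Flux.Optimise.update!(opt, zeros(2), Δ)
# each Descent stage scaled Δ by 0.5, so Δ ≈ [0.25, 0.25]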
# TODO: decay
mutable struct InvDecay
gamma::Float64
state::IdDict
@@ -284,9 +285,7 @@ ExpDecay(opt = 0.001, decay = 0.1, decay_step = 1000, clip = 1e-4) = ExpDecay(op
function update!(o::ExpDecay, x, Δ)
η, s, decay = o.eta, o.step, o.decay
n = o.current[x] = get(o.current, x, 0) + 1
flag = false
count(x -> x%s == 0, values(o.current)) == 1 && (flag = true)
if o.current[x]%s == 0 && flag
if o.current[x]%s == 0 && count(x -> x%s == 0, values(o.current)) == 1
η = max(η * decay^(s / n), o.clip)
o.eta = η
end
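A small sketch of the schedule this implements, based on the constructor defaults above: ExpDecay shrinks its learning-rate factor roughly every decay_step updates and never lets it drop below clip.

using Flux
# decay the ADAM step by a factor of 0.1 every 1000 updates, floored at 1e-4
opt = Flux.Optimise.Optimiser(ADAM(0.001, (0.9, 0.999)), ExpDecay(0.001, 0.1, 1000, 1e-4))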
@@ -298,11 +297,8 @@ mutable struct WeightDecay
end
WeightDecay() = WeightDecay(0)
function update!(o::WeightDecay, x, Δ)
wd = o.wd
@. Δ += wd * x
end
DescentWeightDecay(η = 1, wd = 0) = Optimiser(WeightDecay(wd), Descent(η))
update!(opt::Function, ps) = opt()

View File

@@ -64,12 +64,6 @@ Multiple optimisers and callbacks can be passed to `opt` and `cb` as arrays.
function train!(loss, ps, data, opt; cb = () -> ())
cb = runall(cb)
opt = runall(opt)
opt = try
opt()
opt.opt
catch
opt
end
@progress for d in data
try
l = loss(d...)

View File

@@ -44,8 +44,9 @@ end
l = param(1)
Flux.train!(() -> (sleep(0.1); i += 1; l),
(),
Iterators.repeated((), 100),
() -> (),
Descent(),
cb = Flux.throttle(() -> (i > 3 && Flux.stop()), 1))
@test 3 < i < 50