Mike J Innes 2018-10-31 14:58:55 +00:00
parent bebf4eb95f
commit bffaceee02
5 changed files with 36 additions and 51 deletions

View File

@@ -21,7 +21,7 @@ using .Optimise
 using .Optimise: @epochs
 export Descent, ADAM, Momentum, Nesterov, RMSProp,
   ADAGrad, AdaMax, ADADelta, AMSGrad, NADAM,
-  ADAMW, InvDecay, ExpDecay, WeightDecay, DescentWeightDecay
+  ADAMW, InvDecay, ExpDecay, WeightDecay

 include("utils.jl")
 include("onehot.jl")

View File

@@ -1,17 +1,6 @@
 using Base: depwarn

-function check_decay(opt, decay)
-  if decay == 0.
-    opt = opt
-  else
-    if opt isa ADAMW
-      opt = Optimiser(opt, WeightDecay(decay))
-    else
-      opt = Optimiser(opt, InvDecay(decay))
-    end
-  end
-  opt
-end
+check_decay(opt, decay) = decay == 0 ? opt : Optimiser(opt, InvDecay(decay))

 # legacy update rule
 function updaterule(opt, ps)
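
For reference, a minimal sketch of what the rewritten check_decay does. It is an internal helper of this deprecations file, so the calls below are illustrative only; Descent and InvDecay are the optimisers exported above.

    opt = Descent(0.1)
    check_decay(opt, 0)     # decay of zero: the optimiser is returned unchanged
    check_decay(opt, 0.01)  # nonzero decay: returns Optimiser(Descent(0.1), InvDecay(0.01)),
                            # so InvDecay rescales the gradient after Descent has run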
@@ -24,7 +13,7 @@ function updaterule(opt, ps)
 end

 function Descent(params::AbstractArray, η = 0.1; decay = 0.)
-  depwarn("Descent(ps::Param) is deprecated; use Descent(η::Float64) instead", :Descent)
+  depwarn("Descent(params) is deprecated; use Descent(η::Float64) instead", :Descent)
   ps = params
   opt = Descent(η)
@@ -33,7 +22,7 @@ function Descent(params::AbstractArray, η = 0.1; decay = 0.)
 end

 function Momentum(params::AbstractArray, η = 0.01; ρ = 0.9, decay = 0.)
-  depwarn("Momentum(ps::Param) is deprecated; use Momentum(η::Float64) instead", :Momentum)
+  depwarn("Momentum(params) is deprecated; use Momentum(η::Float64) instead", :Momentum)
   ps = params
   opt = Momentum(η, ρ)
@@ -42,7 +31,7 @@ function Momentum(params::AbstractArray, η = 0.01; ρ = 0.9, decay = 0.)
 end

 function Nesterov(params::AbstractArray, η = 0.001; ρ = 0.9, decay = 0.)
-  depwarn("Nesterov(ps::Param) is deprecated; use Nesterov(η::Float64) instead", :Nesterov)
+  depwarn("Nesterov(params) is deprecated; use Nesterov(η::Float64) instead", :Nesterov)
   ps = params
   opt = Nesterov(η, ρ)
@@ -51,7 +40,7 @@ function Nesterov(params::AbstractArray, η = 0.001; ρ = 0.9, decay = 0.)
 end

 function RMSProp(params::AbstractArray, η = 0.001; ρ = 0.9, decay = 0.)
-  depwarn("RMSProp(ps::Param) is deprecated; use RMSProp(η::Float64) instead", :RMSProp)
+  depwarn("RMSProp(params) is deprecated; use RMSProp(η::Float64) instead", :RMSProp)
   ps = params
   opt = RMSProp(η, ρ)
@@ -60,7 +49,7 @@ function RMSProp(params::AbstractArray, η = 0.001; ρ = 0.9, decay = 0.)
 end

 function ADAM(params::AbstractArray, η = 0.001; β1 = 0.9, β2 = 0.999, decay = 0.)
-  depwarn("ADAM(ps::Param) is deprecated; use ADAM(η::Float64) instead", :ADAM)
+  depwarn("ADAM(params) is deprecated; use ADAM(η::Float64) instead", :ADAM)
   ps = params
   β = (β1, β2)
@@ -70,7 +59,7 @@ function ADAM(params::AbstractArray, η = 0.001; β1 = 0.9, β2 = 0.999, decay =
 end

 function ADAGrad(params::AbstractArray, η::Float64 = 0.1; decay = 0.)
-  depwarn("ADAGrad(ps::Param) is deprecated; use ADAGrad(η::Float64) instead", :ADAGrad)
+  depwarn("ADAGrad(params) is deprecated; use ADAGrad(η::Float64) instead", :ADAGrad)
   ps = params
   opt = ADAGrad(η)
@@ -79,7 +68,7 @@ function ADAGrad(params::AbstractArray, η::Float64 = 0.1; decay = 0.)
 end

 function ADADelta(params::AbstractArray, ρ::Float64 = 0.9; decay = 0.)
-  depwarn("ADADelta(ps::Param) is deprecated; use ADADelta(η::Float64) instead", :ADADelta)
+  depwarn("ADADelta(params) is deprecated; use ADADelta(η::Float64) instead", :ADADelta)
   ps = params
   opt = ADADelta(ρ)
@@ -88,7 +77,7 @@ function ADADelta(params::AbstractArray, ρ::Float64 = 0.9; decay = 0.)
 end

 function AdaMax(params::AbstractArray, η = 0.001; β1 = 0.9, β2 = 0.999, decay = 0.)
-  depwarn("AdaMax(ps::Param) is deprecated; use AdaMax(η::Float64) instead", :AdaMax)
+  depwarn("AdaMax(params) is deprecated; use AdaMax(η::Float64) instead", :AdaMax)
   ps = params
   β = (β1, β2)
@@ -98,7 +87,7 @@ function AdaMax(params::AbstractArray, η = 0.001; β1 = 0.9, β2 = 0.999, decay
 end

 function AMSGrad(params::AbstractArray, η = 0.001; β1 = 0.9, β2 = 0.999, decay = 0.)
-  depwarn("AMSGrad(ps::Param) is deprecated; use AMSGrad(η::Float64) instead", :AMSGrad)
+  depwarn("AMSGrad(params) is deprecated; use AMSGrad(η::Float64) instead", :AMSGrad)
   ps = params
   β = (β1, β2)
@@ -108,7 +97,7 @@ function AMSGrad(params::AbstractArray, η = 0.001; β1 = 0.9, β2 = 0.999, deca
 end

 function NADAM(params::AbstractArray, η = 0.001; β1 = 0.9, β2 = 0.999, decay = 0.)
-  depwarn("NADAM(ps::Param) is deprecated; use NADAM(η::Float64) instead", :NADAM)
+  depwarn("NADAM(params) is deprecated; use NADAM(η::Float64) instead", :NADAM)
   ps = params
   β = (β1, β2)
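
All ten shims above follow the same pattern: warn, build the new hyperparameter-only optimiser, fold in any decay, and fall back to the legacy update rule. A hedged sketch of what this means for user code (m is a placeholder model):

    ADAM(params(m), 0.001)   # old style: warns "ADAM(params) is deprecated; use ADAM(η::Float64) instead"
    ADAM(0.001)              # new style: the parameters are passed to train! instead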
@@ -118,21 +107,26 @@ function NADAM(params::AbstractArray, η = 0.001; β1 = 0.9, β2 = 0.999, decay
 end

 function ADAMW(params::AbstractArray, η = 0.001; β1 = 0.9, β2 = 0.999, decay = 0.)
-  depwarn("ADAMW(ps::Param) is deprecated; use ADAMW(η::Float64) instead", :ADAMW)
+  depwarn("ADAMW(params) is deprecated; use ADAMW(η::Float64) instead", :ADAMW)
   ps = params
   β = (β1, β2)
   opt = ADAMW(η, β)
   opt = check_decay(opt, decay)
+  decay != 0 && (opt = Optimiser(opt, WeightDecay(decay)))
   updaterule(opt, ps)
 end

-# Train function
-function train!(loss::Function, data, opt; cb = () -> ())
-  depwarn("train!(loss, data, opt; cb) is deprecated; use train!(loss, params, data, opt; cb) instead", :train)
-  if fieldnames(typeof(opt)) !== ()
-    train!(loss, opt.ps, data, opt.opt; cb = cb)
-  else
-    train!(loss, (), data, opt; cb = cb)
-  end
-end
+# Old training loop
+
+struct OldOptimiser
+  func
+end
+
+update!(opt::OldOptimiser, ps) = opt.func()
+
+# Train function
+function train!(loss, data, opt; cb = () -> ())
+  depwarn("train!(loss, data, opt) is deprecated; use train!(loss, params, data, opt) instead", :train!)
+  train!(loss, (), data, OldOptimiser(opt); cb = cb)
+end
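
A short sketch of how the new OldOptimiser wrapper bridges the two training APIs; `legacy`, `loss` and `data` below are placeholders, not names from this commit:

    legacy = () -> nothing                              # zero-argument callable, as the old API expected
    Flux.train!(loss, data, legacy)                     # deprecated signature: warns, then forwards to
    Flux.train!(loss, (), data, OldOptimiser(legacy))   # ...where update!(OldOptimiser(legacy), ps) calls legacy()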

View File

@@ -17,6 +17,7 @@ mutable struct Descent
 end

 Descent() = Descent(0.1)
+
 function update!(o::Descent, x, Δ)
   Δ .*= o.eta
 end
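
For orientation, a sketch of the update! contract used throughout this file: each rule rewrites the gradient Δ, and the training loop then applies the result to the parameter. The final subtraction below is an assumption for illustration, not part of this hunk:

    o = Descent(0.1)
    x = [1.0, 2.0]                        # parameter
    Δ = [0.4, 0.4]                        # gradient
    Δ = Flux.Optimise.update!(o, x, Δ)    # Δ .*= 0.1  →  [0.04, 0.04]
    x .-= Δ                               # assumed application step: x == [0.96, 1.96]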
@@ -152,7 +153,7 @@ function update!(o::ADAGrad, x, Δ)
 end

 """
-    ADADelta(params; ρ = 0.9, ϵ = 1e-8)
+    ADADelta(ρ = 0.9, ϵ = 1e-8)

 [ADADelta](http://arxiv.org/abs/1212.5701) optimiser. Parameters don't need
 tuning.
@@ -222,16 +223,18 @@ function update!(o::NADAM, x, Δ)
 end

 """
-    ADAMW((η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0)
+    ADAMW((η = 0.001, β = (0.9, 0.999), decay = 0)

 [ADAMW](https://arxiv.org/abs/1711.05101) fixing weight decay regularization in Adam.
 """
-ADAMW(η = 0.001, β = (0.9, 0.999), η_decay = 1, wd = 0) = Optimiser(ADAM(η, β, IdDict()), DescentWeightDecay(η_decay, wd))
+ADAMW(η = 0.001, β = (0.9, 0.999), decay = 0) =
+  Optimiser(ADAM(η, β), WeightDecay(wd))

 # Compose optimizers
 """
     Optimiser(a, b, c...)

 Combine several optimisers into one; each optimiser produces a modified gradient
 that will be fed into the next, and this is finally applied to the parameter as
 usual.
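
The docstring above describes Optimiser as a pipeline: each rule rewrites the gradient in turn. A hedged sketch of such a composition, which is also what the new ADAMW definition builds (reading the `wd` on its right-hand side as the `decay` argument; the coefficients below are arbitrary):

    opt = Optimiser(ADAM(0.001, (0.9, 0.999)), WeightDecay(0.0005))
    # per update, Δ first goes through the ADAM rule, then WeightDecay adds 0.0005 .* x to it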
@@ -254,8 +257,6 @@ function update!(o::Optimiser, x, Δ)
   return Δ
 end

-# TODO: decay
-
 mutable struct InvDecay
   gamma::Float64
   state::IdDict
@@ -284,9 +285,7 @@ ExpDecay(opt = 0.001, decay = 0.1, decay_step = 1000, clip = 1e-4) = ExpDecay(op
 function update!(o::ExpDecay, x, Δ)
   η, s, decay = o.eta, o.step, o.decay
   n = o.current[x] = get(o.current, x, 0) + 1
-  flag = false
-  count(x -> x%s == 0, values(o.current)) == 1 && (flag = true)
-  if o.current[x]%s == 0 && count(x -> x%s == 0, values(o.current)) == 1
+  if o.current[x]%s == 0 && count(x -> x%s == 0, values(o.current)) == 1
     η = max(η * decay^(s / n), o.clip)
     o.eta = η
   end
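
The rewritten condition fires when a parameter's step counter hits a multiple of decay_step and it is the only counter at such a boundary. A worked sketch of the schedule using the defaults from the hunk header, ExpDecay(0.001, 0.1, 1000, 1e-4):

    # n = 1000:  η = max(0.001 * 0.1^(1000/1000), 1e-4) = 1e-4   (hits the clip immediately)
    # n = 2000:  η = max(1e-4  * 0.1^(1000/2000), 1e-4) = 1e-4   (stays clipped)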
@@ -298,11 +297,8 @@ mutable struct WeightDecay
 end

 WeightDecay() = WeightDecay(0)

 function update!(o::WeightDecay, x, Δ)
   wd = o.wd
   @. Δ += wd * x
 end
-
-DescentWeightDecay(η = 1, wd = 0) = Optimiser(WeightDecay(wd), Descent(η))
-update!(opt::Function, ps) = opt()
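
DescentWeightDecay is gone, but the same behaviour can still be spelled out with the remaining pieces; a sketch with arbitrary coefficients:

    opt = Optimiser(WeightDecay(5e-4), Descent(0.1))
    # WeightDecay first adds 5e-4 .* x to the gradient (an L2 penalty), Descent then scales it by 0.1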

View File

@@ -64,12 +64,6 @@ Multiple optimisers and callbacks can be passed to `opt` and `cb` as arrays.
 function train!(loss, ps, data, opt; cb = () -> ())
   cb = runall(cb)
   opt = runall(opt)
-  opt = try
-    opt()
-    opt.opt
-  catch
-    opt
-  end
   @progress for d in data
     try
       l = loss(d...)
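
With the closure-unwrapping removed, train! now takes an optimiser value directly. A minimal hedged usage sketch; the model, loss and data below are made up for illustration:

    using Flux

    m = Dense(10, 2)
    loss(x, y) = Flux.mse(m(x), y)
    data = [(rand(10), rand(2)) for _ = 1:100]

    Flux.train!(loss, params(m), data, Descent(0.1);
                cb = () -> println("loss: ", loss(rand(10), rand(2))))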

View File

@@ -44,8 +44,9 @@
   l = param(1)
   Flux.train!(() -> (sleep(0.1); i += 1; l),
+              (),
               Iterators.repeated((), 100),
-              () -> (),
+              Descent(),
               cb = Flux.throttle(() -> (i > 3 && Flux.stop()), 1))
   @test 3 < i < 50