tweaks

parent bebf4eb95f
commit bffaceee02
@@ -21,7 +21,7 @@ using .Optimise
 using .Optimise: @epochs
 export Descent, ADAM, Momentum, Nesterov, RMSProp,
   ADAGrad, AdaMax, ADADelta, AMSGrad, NADAM,
-  ADAMW, InvDecay, ExpDecay, WeightDecay, DescentWeightDecay
+  ADAMW, InvDecay, ExpDecay, WeightDecay

 include("utils.jl")
 include("onehot.jl")
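
The only change in this first hunk is that `DescentWeightDecay` is no longer exported (its definition is removed further down in this diff). The same behaviour can still be put together from the pieces that stay exported; a minimal sketch with illustrative values, based on the old definition shown later:

# Rough equivalent of the removed DescentWeightDecay(η, wd):
# WeightDecay adds wd * x to the gradient, Descent then scales it by η.
wd, η = 0.01, 0.1
opt = Optimiser(WeightDecay(wd), Descent(η))
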
@@ -1,17 +1,6 @@
 using Base: depwarn

-function check_decay(opt, decay)
-  if decay == 0.
-    opt = opt
-  else
-    if opt isa ADAMW
-      opt = Optimiser(opt, WeightDecay(decay))
-    else
-      opt = Optimiser(opt, InvDecay(decay))
-    end
-  end
-  opt
-end
+check_decay(opt, decay) = decay == 0 ? opt : Optimiser(opt, InvDecay(decay))

 # legacy update rule
 function updaterule(opt, ps)
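
The rewritten `check_decay` drops the `ADAMW` branch (the deprecated `ADAMW` constructor now attaches `WeightDecay` itself, see the hunk further down) and reduces to a one-line wrap-or-return. Roughly, with illustrative values:

check_decay(Descent(0.1), 0)     # decay of zero: the optimiser is returned unchanged
check_decay(Descent(0.1), 0.01)  # otherwise: Optimiser(Descent(0.1), InvDecay(0.01))
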
@@ -24,7 +13,7 @@ function updaterule(opt, ps)
 end

 function Descent(params::AbstractArray, η = 0.1; decay = 0.)
-  depwarn("Descent(ps::Param) is deprecated; use Descent(η::Float64) instead", :Descent)
+  depwarn("Descent(params) is deprecated; use Descent(η::Float64) instead", :Descent)

   ps = params
   opt = Descent(η)
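
Each deprecated array-taking constructor below gets the same message tweak (`ps::Param` becomes `params`). A sketch of a call on the deprecated path next to the new form; the parameter array is a placeholder:

ps = [rand(2, 2)]    # hypothetical parameter array
Descent(ps, 0.1)     # deprecated: warns, then builds Descent(0.1) internally
Descent(0.1)         # new form: just the learning rate
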
@@ -33,7 +22,7 @@ function Descent(params::AbstractArray, η = 0.1; decay = 0.)
 end

 function Momentum(params::AbstractArray, η = 0.01; ρ = 0.9, decay = 0.)
-  depwarn("Momentum(ps::Param) is deprecated; use Momentum(η::Float64) instead", :Momentum)
+  depwarn("Momentum(params) is deprecated; use Momentum(η::Float64) instead", :Momentum)

   ps = params
   opt = Momentum(η, ρ)
@@ -42,7 +31,7 @@ function Momentum(params::AbstractArray, η = 0.01; ρ = 0.9, decay = 0.)
 end

 function Nesterov(params::AbstractArray, η = 0.001; ρ = 0.9, decay = 0.)
-  depwarn("Nesterov(ps::Param) is deprecated; use Nesterov(η::Float64) instead", :Nesterov)
+  depwarn("Nesterov(params) is deprecated; use Nesterov(η::Float64) instead", :Nesterov)

   ps = params
   opt = Nesterov(η, ρ)
@@ -51,7 +40,7 @@ function Nesterov(params::AbstractArray, η = 0.001; ρ = 0.9, decay = 0.)
 end

 function RMSProp(params::AbstractArray, η = 0.001; ρ = 0.9, decay = 0.)
-  depwarn("RMSProp(ps::Param) is deprecated; use RMSProp(η::Float64) instead", :RMSProp)
+  depwarn("RMSProp(params) is deprecated; use RMSProp(η::Float64) instead", :RMSProp)

   ps = params
   opt = RMSProp(η, ρ)
@@ -60,7 +49,7 @@ function RMSProp(params::AbstractArray, η = 0.001; ρ = 0.9, decay = 0.)
 end

 function ADAM(params::AbstractArray, η = 0.001; β1 = 0.9, β2 = 0.999, decay = 0.)
-  depwarn("ADAM(ps::Param) is deprecated; use ADAM(η::Float64) instead", :ADAM)
+  depwarn("ADAM(params) is deprecated; use ADAM(η::Float64) instead", :ADAM)

   ps = params
   β = (β1, β2)
@@ -70,7 +59,7 @@ function ADAM(params::AbstractArray, η = 0.001; β1 = 0.9, β2 = 0.999, decay = 0.)
 end

 function ADAGrad(params::AbstractArray, η::Float64 = 0.1; decay = 0.)
-  depwarn("ADAGrad(ps::Param) is deprecated; use ADAGrad(η::Float64) instead", :ADAGrad)
+  depwarn("ADAGrad(params) is deprecated; use ADAGrad(η::Float64) instead", :ADAGrad)

   ps = params
   opt = ADAGrad(η)
@@ -79,7 +68,7 @@ function ADAGrad(params::AbstractArray, η::Float64 = 0.1; decay = 0.)
 end

 function ADADelta(params::AbstractArray, ρ::Float64 = 0.9; decay = 0.)
-  depwarn("ADADelta(ps::Param) is deprecated; use ADADelta(η::Float64) instead", :ADADelta)
+  depwarn("ADADelta(params) is deprecated; use ADADelta(η::Float64) instead", :ADADelta)

   ps = params
   opt = ADADelta(ρ)
@@ -88,7 +77,7 @@ function ADADelta(params::AbstractArray, ρ::Float64 = 0.9; decay = 0.)
 end

 function AdaMax(params::AbstractArray, η = 0.001; β1 = 0.9, β2 = 0.999, decay = 0.)
-  depwarn("AdaMax(ps::Param) is deprecated; use AdaMax(η::Float64) instead", :AdaMax)
+  depwarn("AdaMax(params) is deprecated; use AdaMax(η::Float64) instead", :AdaMax)

   ps = params
   β = (β1, β2)
@@ -98,7 +87,7 @@ function AdaMax(params::AbstractArray, η = 0.001; β1 = 0.9, β2 = 0.999, decay = 0.)
 end

 function AMSGrad(params::AbstractArray, η = 0.001; β1 = 0.9, β2 = 0.999, decay = 0.)
-  depwarn("AMSGrad(ps::Param) is deprecated; use AMSGrad(η::Float64) instead", :AMSGrad)
+  depwarn("AMSGrad(params) is deprecated; use AMSGrad(η::Float64) instead", :AMSGrad)

   ps = params
   β = (β1, β2)
@@ -108,7 +97,7 @@ function AMSGrad(params::AbstractArray, η = 0.001; β1 = 0.9, β2 = 0.999, decay = 0.)
 end

 function NADAM(params::AbstractArray, η = 0.001; β1 = 0.9, β2 = 0.999, decay = 0.)
-  depwarn("NADAM(ps::Param) is deprecated; use NADAM(η::Float64) instead", :NADAM)
+  depwarn("NADAM(params) is deprecated; use NADAM(η::Float64) instead", :NADAM)

   ps = params
   β = (β1, β2)
@@ -118,21 +107,26 @@ function NADAM(params::AbstractArray, η = 0.001; β1 = 0.9, β2 = 0.999, decay = 0.)
 end

 function ADAMW(params::AbstractArray, η = 0.001; β1 = 0.9, β2 = 0.999, decay = 0.)
-  depwarn("ADAMW(ps::Param) is deprecated; use ADAMW(η::Float64) instead", :ADAMW)
+  depwarn("ADAMW(params) is deprecated; use ADAMW(η::Float64) instead", :ADAMW)

   ps = params
   β = (β1, β2)
   opt = ADAMW(η, β)
-  opt = check_decay(opt, decay)
+  decay != 0 && (opt = Optimiser(opt, WeightDecay(decay)))
   updaterule(opt, ps)
 end

-# Train function
-function train!(loss::Function, data, opt; cb = () -> ())
-  depwarn("train!(loss, data, opt; cb) is deprecated; use train!(loss, params, data, opt; cb) instead", :train)
-  if fieldnames(typeof(opt)) !== ()
-    train!(loss, opt.ps, data, opt.opt; cb = cb)
-  else
-    train!(loss, (), data, opt; cb = cb)
-  end
+# Old training loop
+
+struct OldOptimiser
+  func
+end
+
+update!(opt::OldOptimiser, ps) = opt.func()
+
+# Train function
+function train!(loss, data, opt; cb = () -> ())
+  depwarn("train!(loss, data, opt) is deprecated; use train!(loss, params, data, opt) instead", :train!)
+  train!(loss, (), data, OldOptimiser(opt); cb = cb)
 end
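
Two things change here: the deprecated `ADAMW(params, ...)` now appends decoupled `WeightDecay` directly instead of going through `check_decay`, and the old three-argument `train!` is kept alive by wrapping whatever it was given in `OldOptimiser`, whose `update!` simply calls the stored function. A minimal sketch of the wrapper, assuming the legacy optimiser was a zero-argument closure as in the old API:

legacy = () -> nothing     # hypothetical old-style optimiser closure
wrapped = OldOptimiser(legacy)
update!(wrapped, ())       # ignores the params and just runs legacy()
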
@@ -17,6 +17,7 @@ mutable struct Descent
 end

 Descent() = Descent(0.1)
+
 function update!(o::Descent, x, Δ)
   Δ .*= o.eta
 end
@@ -152,7 +153,7 @@ function update!(o::ADAGrad, x, Δ)
 end

 """
-  ADADelta(params; ρ = 0.9, ϵ = 1e-8)
+  ADADelta(ρ = 0.9, ϵ = 1e-8)

 [ADADelta](http://arxiv.org/abs/1212.5701) optimiser. Parameters don't need
 tuning.
@@ -222,16 +223,18 @@ function update!(o::NADAM, x, Δ)
 end

 """
-  ADAMW((η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0)
+  ADAMW(η = 0.001, β = (0.9, 0.999), decay = 0)

 [ADAMW](https://arxiv.org/abs/1711.05101) fixing weight decay regularization in Adam.
 """
-ADAMW(η = 0.001, β = (0.9, 0.999), η_decay = 1, wd = 0) = Optimiser(ADAM(η, β, IdDict()), DescentWeightDecay(η_decay, wd))
+ADAMW(η = 0.001, β = (0.9, 0.999), decay = 0) =
+  Optimiser(ADAM(η, β), WeightDecay(decay))

 # Compose optimizers

 """
   Optimiser(a, b, c...)

 Combine several optimisers into one; each optimiser produces a modified gradient
 that will be fed into the next, and this is finally applied to the parameter as
 usual.
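
The docstring above spells out the composition rule: each optimiser in the chain rewrites the gradient before the next one sees it. The new `ADAMW` is itself an instance of that rule; a sketch with illustrative hyperparameters:

# ADAMW is now plain ADAM followed by decoupled weight decay:
opt = Optimiser(ADAM(0.001, (0.9, 0.999)), WeightDecay(0.01))
# Conceptually the chained update is
#   Δ′ = update!(WeightDecay(0.01), x, update!(ADAM(0.001, (0.9, 0.999)), x, Δ))
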
@@ -254,8 +257,6 @@ function update!(o::Optimiser, x, Δ)
   return Δ
 end

-# TODO: decay
-
 mutable struct InvDecay
   gamma::Float64
   state::IdDict
@@ -284,9 +285,7 @@ ExpDecay(opt = 0.001, decay = 0.1, decay_step = 1000, clip = 1e-4) = ExpDecay(op
 function update!(o::ExpDecay, x, Δ)
   η, s, decay = o.eta, o.step, o.decay
   n = o.current[x] = get(o.current, x, 0) + 1
-  flag = false
-  count(x -> x%s == 0, values(o.current)) == 1 && (flag = true)
-  if o.current[x]%s == 0 && flag
+  if o.current[x]%s == 0 && count(x -> x%s == 0, values(o.current)) == 1
     η = max(η * decay^(s / n), o.clip)
     o.eta = η
   end
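
The refactor folds the `flag` bookkeeping into the `if` condition; the schedule itself is unchanged: when a parameter's update count `n` hits a multiple of the step size `s` (and only one tracked counter sits at such a multiple), the learning rate is scaled by `decay^(s / n)` and floored at `clip`. A small numeric sketch with illustrative values (the starting η differs from the default):

η, decay, s, clip = 0.01, 0.1, 1000, 1e-4
for n in (1000, 2000, 3000)   # update counts where n % s == 0
  η = max(η * decay^(s / n), clip)
  @show n η
end
# n = 1000 → η ≈ 1.0e-3;  n = 2000 → η ≈ 3.16e-4;  n = 3000 → η ≈ 1.47e-4
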
@@ -298,11 +297,8 @@ mutable struct WeightDecay
 end

 WeightDecay() = WeightDecay(0)
+
 function update!(o::WeightDecay, x, Δ)
   wd = o.wd
   @. Δ += wd * x
 end
-
-DescentWeightDecay(η = 1, wd = 0) = Optimiser(WeightDecay(wd), Descent(η))
-
-update!(opt::Function, ps) = opt()
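
With `DescentWeightDecay` and the bare-function fallback removed, `WeightDecay` is the remaining decay building block; its `update!` just adds `wd * x` to the incoming gradient. A tiny worked example with illustrative numbers:

x  = [1.0, -2.0]   # parameter values
Δ  = [0.1,  0.1]   # incoming gradient
wd = 0.01
@. Δ += wd * x     # what update!(::WeightDecay, x, Δ) does to the gradient
# Δ == [0.11, 0.08]
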
@@ -64,12 +64,6 @@ Multiple optimisers and callbacks can be passed to `opt` and `cb` as arrays.
 function train!(loss, ps, data, opt; cb = () -> ())
   cb = runall(cb)
   opt = runall(opt)
-  opt = try
-    opt()
-    opt.opt
-  catch
-    opt
-  end
   @progress for d in data
     try
       l = loss(d...)
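
With the `try opt() ... catch` unwrapping gone, `train!` now expects an optimiser value up front; old callable optimisers only survive through the `OldOptimiser` shim above. A usage sketch mirroring the updated test below, assuming the same environment as that test (`using Flux`); the data length is illustrative:

l = param(1)    # tracked scalar so the loss has a gradient, as in the test
loss() = l      # zero-argument loss, as in the test
Flux.train!(loss, (), Iterators.repeated((), 5), Descent(0.1);
            cb = () -> nothing)
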
@@ -44,8 +44,9 @@ end
   l = param(1)

   Flux.train!(() -> (sleep(0.1); i += 1; l),
+              (),
               Iterators.repeated((), 100),
-              () -> (),
+              Descent(),
               cb = Flux.throttle(() -> (i > 3 && Flux.stop()), 1))

   @test 3 < i < 50