Add ADAMW weight-decay.

See http://www.fast.ai/2018/07/02/adam-weight-decay/ and the original
paper https://arxiv.org/abs/1711.05101.pdf for context.

I don't know what I'm doing, and this is quite possibly wrong, but on a
simple Char-RNN I have lying around on my hard disk this seems to improve
the rate of learning consistently across different hyperparameters,
compared with standard ADAM using the same decay constant.
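
The core of the change is the decoupled update from the paper: the decay term is applied directly to the weights at the descent step, rather than being folded into the gradient that ADAM's moment estimates see. A minimal sketch of a single update (plain Julia with made-up values, not the Flux API):

    η, γ = 0.001, 0.01          # learning rate and decay constant (illustrative)
    x = randn(3)                # a parameter vector
    g = randn(3)                # stand-in for the ADAM-preconditioned step m̂ / (√v̂ + ϵ)
    @. x = x - η * (g + γ * x)  # decay acts on the weights directly, not through the moments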
Author: Jarvist Moore Frost, 2018-07-03 11:11:32 +01:00
Parent: e92f840510
Commit: aee4a83c55
4 changed files with 19 additions and 3 deletions


@@ -22,7 +22,7 @@ export Tracker, TrackedArray, TrackedVector, TrackedMatrix, param
 include("optimise/Optimise.jl")
 using .Optimise
 using .Optimise: @epochs
-export SGD, ADAM, AdaMax, Momentum, Nesterov,
+export SGD, ADAM, ADAMW, AdaMax, Momentum, Nesterov,
   RMSProp, ADAGrad, ADADelta, AMSGrad
 include("utils.jl")


@@ -1,7 +1,7 @@
 module Optimise
 export train!,
-  SGD, ADAM, AdaMax, Momentum, Nesterov, RMSProp, ADAGrad, ADADelta, AMSGrad
+  SGD, ADAM, ADAMW, AdaMax, Momentum, Nesterov, RMSProp, ADAGrad, ADADelta, AMSGrad
 struct Param{T}
   x::T


@@ -1,7 +1,7 @@
 call(f, xs...) = f(xs...)
 # note for optimisers: set to zero
-# p.Δ at the end of the weigths update
+# p.Δ at the end of the weights update
 function optimiser(ps, fs...)
   ps = [Param(p) for p in ps]
   fs = map(ps) do p
@@ -56,6 +56,14 @@ RMSProp(ps, η = 0.001; ρ = 0.9, ϵ = 1e-8, decay = 0) =
 ADAM(ps, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0) =
   optimiser(ps, p->adam(p; η=η, β1=β1, β2=β2, ϵ=ϵ), p->invdecay(p,decay), p->descent(p,1))
+"""
+    ADAMW(params, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0)
+
+[ADAMW](https://arxiv.org/abs/1711.05101) fixes weight decay regularisation in Adam.
+"""
+ADAMW(ps, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0) =
+  optimiser(ps, p->adam(p; η=η, β1=β1, β2=β2, ϵ=ϵ), p->descentweightdecay(p,1,decay))
 """
     AdaMax(params, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0)
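
Presumably the new constructor is used exactly like the existing ADAM one shown above: pass the parameters and learning rate to the constructor, then call train!. A hedged sketch against the Flux API of that era (the model, loss and data below are placeholders, not part of the commit):

    using Flux
    m = Dense(10, 2)
    loss(x, y) = Flux.mse(m(x), y)
    opt = ADAMW(params(m), 0.001; decay = 0.05)   # same signature as ADAM; decay is the γ used below
    Flux.train!(loss, [(rand(10), rand(2))], opt)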


@@ -5,6 +5,14 @@ function descent(p::Param, η::Real)
   end
 end
+
+# Ref: https://arxiv.org/abs/1711.05101.pdf
+function descentweightdecay(p::Param, η::Real, γ::Real)
+  function ()
+    @. p.x = p.x - η * (p.Δ + γ * p.x)
+    @. p.Δ = 0
+  end
+end
 function momentum(p::Param, ρ, η)
   v = zeros(p.x)
   function ()
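
As a sanity check on the new rule, a standalone re-creation that runs outside Flux (the Param stand-in and the numbers are illustrative, not part of the commit):

    struct P                       # stand-in for Flux's Param: value and gradient
      x::Vector{Float64}
      Δ::Vector{Float64}
    end

    function descentweightdecay(p::P, η::Real, γ::Real)
      function ()                            # closure, called once per update step
        @. p.x = p.x - η * (p.Δ + γ * p.x)   # decoupled decay: γ * x is not rescaled by ADAM's moments
        @. p.Δ = 0                           # zero the gradient, as the other optimisers do
      end
    end

    p = P([1.0, -2.0], [0.5, 0.5])
    step! = descentweightdecay(p, 0.1, 0.01)
    step!()                        # p.x ≈ [0.949, -2.048]; p.Δ is reset to zeros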