Add ADAMW weight-decay.

See http://www.fast.ai/2018/07/02/adam-weight-decay/ and the original
paper https://arxiv.org/abs/1711.05101.pdf for context.

I don't know what I'm doing, and this is quite possibly wrong - but on
a simple Char-RNN I have lying around on my harddisk, this seems to
improve the rate of learning consistently for different hyperparameters
vs. standard ADAM with the same decay constant.
This commit is contained in:
Jarvist Moore Frost 2018-07-03 11:11:32 +01:00
parent e92f840510
commit aee4a83c55
4 changed files with 19 additions and 3 deletions

View File

@ -22,7 +22,7 @@ export Tracker, TrackedArray, TrackedVector, TrackedMatrix, param
include("optimise/Optimise.jl")
using .Optimise
using .Optimise: @epochs
export SGD, ADAM, AdaMax, Momentum, Nesterov,
export SGD, ADAM, ADAMW, AdaMax, Momentum, Nesterov,
RMSProp, ADAGrad, ADADelta, AMSGrad
include("utils.jl")

View File

@ -1,7 +1,7 @@
module Optimise
export train!,
SGD, ADAM, AdaMax, Momentum, Nesterov, RMSProp, ADAGrad, ADADelta, AMSGrad
SGD, ADAM, ADAMW, AdaMax, Momentum, Nesterov, RMSProp, ADAGrad, ADADelta, AMSGrad
struct Param{T}
x::T

View File

@ -1,7 +1,7 @@
call(f, xs...) = f(xs...)
# note for optimisers: set to zero
# p.Δ at the end of the weigths update
# p.Δ at the end of the weights update
function optimiser(ps, fs...)
ps = [Param(p) for p in ps]
fs = map(ps) do p
@ -56,6 +56,14 @@ RMSProp(ps, η = 0.001; ρ = 0.9, ϵ = 1e-8, decay = 0) =
ADAM(ps, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0) =
optimiser(ps, p->adam(p; η=η, β1=β1, β2=β2, ϵ=ϵ), p->invdecay(p,decay), p->descent(p,1))
"""
ADAMW((params, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0)
[ADAMW](https://arxiv.org/abs/1711.05101) fixing weight decay regularization in Adam.
"""
ADAMW(ps, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0) =
optimiser(ps, p->adam(p; η=η, β1=β1, β2=β2, ϵ=ϵ), p->descentweightdecay(p,1,decay))
"""
AdaMax(params, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0)

View File

@ -5,6 +5,14 @@ function descent(p::Param, η::Real)
end
end
# Ref: https://arxiv.org/abs/1711.05101.pdf
function descentweightdecay(p::Param, η::Real, γ::Real)
function ()
@. p.x = p.x - η * (p.Δ + γ * p.x)
@. p.Δ = 0
end
end
function momentum(p::Param, ρ, η)
v = zeros(p.x)
function ()