Add ADAMW weight-decay.
See http://www.fast.ai/2018/07/02/adam-weight-decay/ and the original paper https://arxiv.org/abs/1711.05101.pdf for context. I don't know what I'm doing, and this is quite possibly wrong - but on a simple Char-RNN I have lying around on my harddisk, this seems to improve the rate of learning consistently for different hyperparameters vs. standard ADAM with the same decay constant.
This commit is contained in:
parent
e92f840510
commit
aee4a83c55
@ -22,7 +22,7 @@ export Tracker, TrackedArray, TrackedVector, TrackedMatrix, param
|
||||
include("optimise/Optimise.jl")
|
||||
using .Optimise
|
||||
using .Optimise: @epochs
|
||||
export SGD, ADAM, AdaMax, Momentum, Nesterov,
|
||||
export SGD, ADAM, ADAMW, AdaMax, Momentum, Nesterov,
|
||||
RMSProp, ADAGrad, ADADelta, AMSGrad
|
||||
|
||||
include("utils.jl")
|
||||
|
@ -1,7 +1,7 @@
|
||||
module Optimise
|
||||
|
||||
export train!,
|
||||
SGD, ADAM, AdaMax, Momentum, Nesterov, RMSProp, ADAGrad, ADADelta, AMSGrad
|
||||
SGD, ADAM, ADAMW, AdaMax, Momentum, Nesterov, RMSProp, ADAGrad, ADADelta, AMSGrad
|
||||
|
||||
struct Param{T}
|
||||
x::T
|
||||
|
@ -1,7 +1,7 @@
|
||||
call(f, xs...) = f(xs...)
|
||||
|
||||
# note for optimisers: set to zero
|
||||
# p.Δ at the end of the weigths update
|
||||
# p.Δ at the end of the weights update
|
||||
function optimiser(ps, fs...)
|
||||
ps = [Param(p) for p in ps]
|
||||
fs = map(ps) do p
|
||||
@ -56,6 +56,14 @@ RMSProp(ps, η = 0.001; ρ = 0.9, ϵ = 1e-8, decay = 0) =
|
||||
ADAM(ps, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0) =
|
||||
optimiser(ps, p->adam(p; η=η, β1=β1, β2=β2, ϵ=ϵ), p->invdecay(p,decay), p->descent(p,1))
|
||||
|
||||
"""
|
||||
ADAMW((params, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0)
|
||||
|
||||
[ADAMW](https://arxiv.org/abs/1711.05101) fixing weight decay regularization in Adam.
|
||||
"""
|
||||
ADAMW(ps, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0) =
|
||||
optimiser(ps, p->adam(p; η=η, β1=β1, β2=β2, ϵ=ϵ), p->descentweightdecay(p,1,decay))
|
||||
|
||||
"""
|
||||
AdaMax(params, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0)
|
||||
|
||||
|
@ -5,6 +5,14 @@ function descent(p::Param, η::Real)
|
||||
end
|
||||
end
|
||||
|
||||
# Ref: https://arxiv.org/abs/1711.05101.pdf
|
||||
function descentweightdecay(p::Param, η::Real, γ::Real)
|
||||
function ()
|
||||
@. p.x = p.x - η * (p.Δ + γ * p.x)
|
||||
@. p.Δ = 0
|
||||
end
|
||||
end
|
||||
|
||||
function momentum(p::Param, ρ, η)
|
||||
v = zeros(p.x)
|
||||
function ()
|
||||
|
Loading…
Reference in New Issue
Block a user