Merge pull request #307 from jarvist/master

Add ADAMW "Fixing Weight Decay Regularization in Adam"
This commit is contained in:
Mike J Innes 2018-07-11 19:12:58 +01:00 committed by GitHub
commit a0fd91b866
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 19 additions and 3 deletions

View File

@ -22,7 +22,7 @@ export Tracker, TrackedArray, TrackedVector, TrackedMatrix, param
include("optimise/Optimise.jl") include("optimise/Optimise.jl")
using .Optimise using .Optimise
using .Optimise: @epochs using .Optimise: @epochs
export SGD, ADAM, AdaMax, Momentum, Nesterov, export SGD, ADAM, ADAMW, AdaMax, Momentum, Nesterov,
RMSProp, ADAGrad, ADADelta, AMSGrad, NADAM RMSProp, ADAGrad, ADADelta, AMSGrad, NADAM
include("utils.jl") include("utils.jl")

View File

@ -1,7 +1,7 @@
module Optimise module Optimise
export train!, export train!,
SGD, ADAM, AdaMax, Momentum, Nesterov, SGD, ADAM, ADAMW, AdaMax, Momentum, Nesterov,
RMSProp, ADAGrad, ADADelta, AMSGrad, NADAM RMSProp, ADAGrad, ADADelta, AMSGrad, NADAM
struct Param{T} struct Param{T}

View File

@ -1,7 +1,7 @@
call(f, xs...) = f(xs...) call(f, xs...) = f(xs...)
# note for optimisers: set to zero # note for optimisers: set to zero
# p.Δ at the end of the weigths update # p.Δ at the end of the weights update
function optimiser(ps, fs...) function optimiser(ps, fs...)
ps = [Param(p) for p in ps] ps = [Param(p) for p in ps]
fs = map(ps) do p fs = map(ps) do p
@ -56,6 +56,14 @@ RMSProp(ps, η = 0.001; ρ = 0.9, ϵ = 1e-8, decay = 0) =
ADAM(ps, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0) = ADAM(ps, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0) =
optimiser(ps, p->adam(p; η=η, β1=β1, β2=β2, ϵ=ϵ), p->invdecay(p,decay), p->descent(p,1)) optimiser(ps, p->adam(p; η=η, β1=β1, β2=β2, ϵ=ϵ), p->invdecay(p,decay), p->descent(p,1))
"""
ADAMW((params, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0)
[ADAMW](https://arxiv.org/abs/1711.05101) fixing weight decay regularization in Adam.
"""
ADAMW(ps, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0) =
optimiser(ps, p->adam(p; η=η, β1=β1, β2=β2, ϵ=ϵ), p->descentweightdecay(p,1,decay))
""" """
AdaMax(params, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0) AdaMax(params, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0)

View File

@ -5,6 +5,14 @@ function descent(p::Param, η::Real)
end end
end end
# Ref: https://arxiv.org/abs/1711.05101
"""
    descentweightdecay(p::Param, η::Real, γ::Real)

Build the update step for `p` that adds a weight-decay term to a descent step.

The returned zero-argument closure overwrites `p.x` in place with
`p.x - η * (p.Δ + γ * p.x)` and then resets every entry of `p.Δ` to zero.
"""
function descentweightdecay(p::Param, η::Real, γ::Real)
  return () -> begin
    # Single fused broadcast: gradient step plus the γ-scaled pull on the weights.
    p.x .-= η .* (p.Δ .+ γ .* p.x)
    # Clear the accumulated update so the next step starts from zero.
    p.Δ .= 0
  end
end
function momentum(p::Param, ρ, η) function momentum(p::Param, ρ, η)
v = zeros(p.x) v = zeros(p.x)
function () function ()