From aee4a83c556c748ce72a69852b0b3d5e72998476 Mon Sep 17 00:00:00 2001
From: Jarvist Moore Frost
Date: Tue, 3 Jul 2018 11:11:32 +0100
Subject: [PATCH] Add ADAMW weight decay

See http://www.fast.ai/2018/07/02/adam-weight-decay/ and the original paper
https://arxiv.org/abs/1711.05101.pdf for context.

This may well be wrong, but on a simple Char-RNN I have lying around on my
hard disk it consistently improves the rate of learning across different
hyperparameters compared to standard ADAM with the same decay constant.
---
 src/Flux.jl                | 2 +-
 src/optimise/Optimise.jl   | 2 +-
 src/optimise/interface.jl  | 10 +++++++++-
 src/optimise/optimisers.jl | 8 ++++++++
 4 files changed, 19 insertions(+), 3 deletions(-)

diff --git a/src/Flux.jl b/src/Flux.jl
index 7125630f..eeda5492 100644
--- a/src/Flux.jl
+++ b/src/Flux.jl
@@ -22,7 +22,7 @@ export Tracker, TrackedArray, TrackedVector, TrackedMatrix, param
 include("optimise/Optimise.jl")
 using .Optimise
 using .Optimise: @epochs
-export SGD, ADAM, AdaMax, Momentum, Nesterov,
+export SGD, ADAM, ADAMW, AdaMax, Momentum, Nesterov,
   RMSProp, ADAGrad, ADADelta, AMSGrad

 include("utils.jl")
diff --git a/src/optimise/Optimise.jl b/src/optimise/Optimise.jl
index 5d5d9ea0..eb9aaa87 100644
--- a/src/optimise/Optimise.jl
+++ b/src/optimise/Optimise.jl
@@ -1,7 +1,7 @@
 module Optimise

 export train!,
-  SGD, ADAM, AdaMax, Momentum, Nesterov, RMSProp, ADAGrad, ADADelta, AMSGrad
+  SGD, ADAM, ADAMW, AdaMax, Momentum, Nesterov, RMSProp, ADAGrad, ADADelta, AMSGrad

 struct Param{T}
   x::T
diff --git a/src/optimise/interface.jl b/src/optimise/interface.jl
index 29068983..01c76391 100644
--- a/src/optimise/interface.jl
+++ b/src/optimise/interface.jl
@@ -1,7 +1,7 @@
 call(f, xs...) = f(xs...)

 # note for optimisers: set to zero
-# p.Δ at the end of the weigths update
+# p.Δ at the end of the weights update
 function optimiser(ps, fs...)
   ps = [Param(p) for p in ps]
   fs = map(ps) do p
@@ -56,6 +56,14 @@ RMSProp(ps, η = 0.001; ρ = 0.9, ϵ = 1e-8, decay = 0) =
 ADAM(ps, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0) =
   optimiser(ps, p->adam(p; η=η, β1=β1, β2=β2, ϵ=ϵ), p->invdecay(p,decay), p->descent(p,1))

+"""
+    ADAMW(params, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0)
+
+[ADAMW](https://arxiv.org/abs/1711.05101) optimiser: ADAM with fixed (decoupled) weight decay regularization.
+"""
+ADAMW(ps, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0) =
+  optimiser(ps, p->adam(p; η=η, β1=β1, β2=β2, ϵ=ϵ), p->descentweightdecay(p,1,decay))
+
 """
     AdaMax(params, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0)

diff --git a/src/optimise/optimisers.jl b/src/optimise/optimisers.jl
index 29b058ba..3a3c8945 100644
--- a/src/optimise/optimisers.jl
+++ b/src/optimise/optimisers.jl
@@ -5,6 +5,14 @@ function descent(p::Param, η::Real)
   end
 end

+# Ref: https://arxiv.org/abs/1711.05101.pdf
+function descentweightdecay(p::Param, η::Real, γ::Real)
+  function ()
+    @. p.x = p.x - η * (p.Δ + γ * p.x)
+    @. p.Δ = 0
+  end
+end
+
 function momentum(p::Param, ρ, η)
   v = zeros(p.x)
   function ()
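
For reference, the step implemented by descentweightdecay is the decoupled weight decay of the paper: the decay term γ * x is subtracted from the weights directly, after the adam() closure has written its update into p.Δ (with the learning rate already applied, which is why the composed ADAMW optimiser passes η = 1 here), rather than being added to the gradient and then rescaled by Adam's adaptive step sizes. Below is a minimal standalone sketch of that arithmetic, using plain Julia arrays and made-up values in place of Flux's Param type, purely for illustration:

    η, γ = 1.0, 0.01          # step size and weight-decay constant (illustrative values)
    x = [1.0, -2.0, 3.0]      # parameters
    Δ = [0.05, 0.02, -0.03]   # update left in p.Δ by the preceding adam() closure

    @. x = x - η * (Δ + γ * x)   # decoupled decay: also shrink the weights by γ * x
    @. Δ = 0                     # zero the accumulated update, per the note in interface.jl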
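
A hypothetical call site, assuming the parameter-list-based optimiser interface Flux used around the time of this patch; the model, loss function, and data below are placeholders for illustration and are not part of the patch:

    using Flux

    # Toy model and a single fabricated data point, just to show the call.
    m = Chain(Dense(10, 5, relu), Dense(5, 2), softmax)
    loss(x, y) = Flux.crossentropy(m(x), y)
    data = [(rand(10), Flux.onehot(1, 1:2))]

    # Same signature as ADAM, but here `decay` is the decoupled weight-decay
    # constant γ handed to descentweightdecay, not a learning-rate decay.
    opt = ADAMW(params(m), 0.001; decay = 0.01)

    Flux.train!(loss, data, opt)

Note that the decay keyword means different things in the two constructors: ADAM routes it through invdecay (a learning-rate schedule), whereas ADAMW passes it to descentweightdecay as the weight-decay constant.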