diff --git a/src/Flux.jl b/src/Flux.jl index 7125630f..eeda5492 100644 --- a/src/Flux.jl +++ b/src/Flux.jl @@ -22,7 +22,7 @@ export Tracker, TrackedArray, TrackedVector, TrackedMatrix, param include("optimise/Optimise.jl") using .Optimise using .Optimise: @epochs -export SGD, ADAM, AdaMax, Momentum, Nesterov, +export SGD, ADAM, ADAMW, AdaMax, Momentum, Nesterov, RMSProp, ADAGrad, ADADelta, AMSGrad include("utils.jl") diff --git a/src/optimise/Optimise.jl b/src/optimise/Optimise.jl index 5d5d9ea0..eb9aaa87 100644 --- a/src/optimise/Optimise.jl +++ b/src/optimise/Optimise.jl @@ -1,7 +1,7 @@ module Optimise export train!, - SGD, ADAM, AdaMax, Momentum, Nesterov, RMSProp, ADAGrad, ADADelta, AMSGrad + SGD, ADAM, ADAMW, AdaMax, Momentum, Nesterov, RMSProp, ADAGrad, ADADelta, AMSGrad struct Param{T} x::T diff --git a/src/optimise/interface.jl b/src/optimise/interface.jl index 29068983..01c76391 100644 --- a/src/optimise/interface.jl +++ b/src/optimise/interface.jl @@ -1,7 +1,7 @@ call(f, xs...) = f(xs...) # note for optimisers: set to zero -# p.Δ at the end of the weigths update +# p.Δ at the end of the weights update function optimiser(ps, fs...) ps = [Param(p) for p in ps] fs = map(ps) do p @@ -56,6 +56,14 @@ RMSProp(ps, η = 0.001; ρ = 0.9, ϵ = 1e-8, decay = 0) = ADAM(ps, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0) = optimiser(ps, p->adam(p; η=η, β1=β1, β2=β2, ϵ=ϵ), p->invdecay(p,decay), p->descent(p,1)) +""" + ADAMW((params, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0) + +[ADAMW](https://arxiv.org/abs/1711.05101) fixing weight decay regularization in Adam. +""" +ADAMW(ps, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0) = + optimiser(ps, p->adam(p; η=η, β1=β1, β2=β2, ϵ=ϵ), p->descentweightdecay(p,1,decay)) + """ AdaMax(params, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0) diff --git a/src/optimise/optimisers.jl b/src/optimise/optimisers.jl index 29b058ba..3a3c8945 100644 --- a/src/optimise/optimisers.jl +++ b/src/optimise/optimisers.jl @@ -5,6 +5,14 @@ function descent(p::Param, η::Real) end end +# Ref: https://arxiv.org/abs/1711.05101.pdf +function descentweightdecay(p::Param, η::Real, γ::Real) + function () + @. p.x = p.x - η * (p.Δ + γ * p.x) + @. p.Δ = 0 + end +end + function momentum(p::Param, ρ, η) v = zeros(p.x) function ()