Merge pull request #842 from baggepinnen/patch-4

Add RADAM optimizer
Mike J Innes 2019-09-02 14:36:40 +01:00 committed by GitHub
commit 3c1ac84676
5 changed files with 34 additions and 3 deletions

@@ -1,6 +1,7 @@
# v0.9.0
* [Depthwise convolutional layer API changes](https://github.com/FluxML/Flux.jl/pull/756) from `in => mult` channel specification to `in => out` channel specification, and deprecates implicit `out` constructor.
* New [SkipConnection](https://github.com/FluxML/Flux.jl/pull/446), which can be used to train residual neural network architectures.
* New [RADAM](https://github.com/FluxML/Flux.jl/pull/842) optimiser.
# v0.8.0
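
The API-facing entries above are easier to follow with a concrete call. A brief sketch, assuming the Flux 0.9 constructors these entries describe (layer sizes and shapes are illustrative only):

```julia
using Flux

# DepthwiseConv now takes an explicit in => out channel pair (out a multiple of in);
# previously the second number was a per-channel multiplier.
dw = DepthwiseConv((3, 3), 3 => 6, relu)
dw(rand(32, 32, 3, 1))              # WHCN input: 3 channels in, 6 channels out

# SkipConnection combines a layer's output with its input, here as a residual sum.
res = SkipConnection(Dense(10, 10), +)
res(rand(10))                       # computes layer(x) + x
```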

@@ -22,7 +22,7 @@ using .Optimise
using .Optimise: @epochs
export SGD, Descent, ADAM, Momentum, Nesterov, RMSProp,
ADAGrad, AdaMax, ADADelta, AMSGrad, NADAM,
ADAMW, InvDecay, ExpDecay, WeightDecay
ADAMW, RADAM, InvDecay, ExpDecay, WeightDecay
using CUDAapi
if has_cuda()

@@ -2,7 +2,7 @@ module Optimise
export train!,
SGD, Descent, ADAM, Momentum, Nesterov, RMSProp,
ADAGrad, AdaMax, ADADelta, AMSGrad, NADAM, ADAMW,
ADAGrad, AdaMax, ADADelta, AMSGrad, NADAM, ADAMW, RADAM,
InvDecay, ExpDecay, WeightDecay, stop, Optimiser
include("optimisers.jl")

@@ -108,6 +108,36 @@ function apply!(o::ADAM, x, Δ)
return Δ
end
"""
RADAM(η = 0.001, β = (0.9, 0.999))
[RADAM](https://arxiv.org/pdf/1908.03265v1.pdf) optimiser (Rectified ADAM).
"""
mutable struct RADAM
eta::Float64
beta::Tuple{Float64,Float64}
state::IdDict
end
RADAM(η = 0.001, β = (0.9, 0.999)) = RADAM(η, β, IdDict())
function apply!(o::RADAM, x, Δ)
  η, β = o.eta, o.beta
  ρ∞ = 2/(1-β[2])-1                     # maximum length of the approximated SMA
  mt, vt, βp, t = get!(o.state, x, (zero(x), zero(x), β, 1))
  @. mt = β[1] * mt + (1 - β[1]) * Δ    # first-moment estimate
  @. vt = β[2] * vt + (1 - β[2]) * Δ^2  # second-moment estimate
  ρ = ρ∞ - 2t*βp[2]/(1-βp[2])           # SMA length at step t (βp holds β.^t)
  if ρ > 4
    # Variance is tractable: apply the rectified, bias-corrected ADAM step.
    # ϵ here is the small module-level constant shared by the other optimisers in this file.
    r = sqrt((ρ-4)*(ρ-2)*ρ∞/((ρ∞-4)*(ρ∞-2)*ρ))
    @. Δ = mt / (1 - βp[1]) / (√(vt / (1 - βp[2])) + ϵ) * η * r
  else
    # Too few steps for a reliable variance estimate: fall back to an un-rectified momentum step.
    @. Δ = mt / (1 - βp[1]) * η
  end
  o.state[x] = (mt, vt, βp .* β, t+1)
  return Δ
end
"""
AdaMax(params, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08)
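
The new `apply!` method above is the whole integration surface for RADAM: it rewrites a raw gradient `Δ` into the step to subtract, keeping per-parameter state keyed on the parameter array in `o.state`. A minimal sketch of driving it by hand (illustrative usage, not part of the diff):

```julia
using Flux
using Flux.Optimise: RADAM, apply!   # apply! is internal, imported explicitly here

x = randn(10, 10)          # parameter array
Δ = randn(10, 10)          # gradient w.r.t. x (placeholder values)
opt = RADAM()              # defaults: η = 0.001, β = (0.9, 0.999)

step = apply!(opt, x, Δ)   # rectified (or, early on, momentum-only) step; mutates Δ
x .-= step                 # plain gradient-descent update with that step
```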

@@ -5,7 +5,7 @@ using Test
@testset "Optimise" begin
w = randn(10, 10)
@testset for opt in [ADAMW(), ADAGrad(0.1), AdaMax(), ADADelta(0.9), AMSGrad(),
NADAM(), Descent(0.1), ADAM(), Nesterov(), RMSProp(),
NADAM(), RADAM(), Descent(0.1), ADAM(), Nesterov(), RMSProp(),
Momentum()]
w′ = param(randn(10, 10))
loss(x) = Flux.mse(w*x, w′*x)
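
For completeness, this is how the new optimiser would be used from the high-level API once this PR lands; a sketch against the Flux 0.9 `train!` interface, with a made-up toy model and data:

```julia
using Flux

m = Dense(10, 1)
x, y = rand(10, 100), rand(1, 100)
loss(x, y) = Flux.mse(m(x), y)

opt = RADAM()                             # the optimiser added in this PR
data = Iterators.repeated((x, y), 200)
Flux.train!(loss, params(m), data, opt)   # params(m) collects m's trainable arrays
```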