Merge pull request #842 from baggepinnen/patch-4

Add RADAM optimizer
Mike J Innes 2019-09-02 14:36:40 +01:00 committed by GitHub
commit 3c1ac84676
5 changed files with 34 additions and 3 deletions


@@ -1,6 +1,7 @@
 # v0.9.0
 * [Depthwise convolutional layer API changes](https://github.com/FluxML/Flux.jl/pull/756) from `in => mult` channel specification to `in => out` channel specification, and deprecates implicit `out` constructor.
 * New [SkipConnection](https://github.com/FluxML/Flux.jl/pull/446), which can be used to train residual neural network architectures.
+* New [RADAM](https://github.com/FluxML/Flux.jl/pull/842) optimiser.

 # v0.8.0


@@ -22,7 +22,7 @@ using .Optimise
 using .Optimise: @epochs
 export SGD, Descent, ADAM, Momentum, Nesterov, RMSProp,
        ADAGrad, AdaMax, ADADelta, AMSGrad, NADAM,
-       ADAMW, InvDecay, ExpDecay, WeightDecay
+       ADAMW, RADAM, InvDecay, ExpDecay, WeightDecay

 using CUDAapi
 if has_cuda()
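For context (not part of the diff): with `RADAM` now exported from `Flux`, a training call works the same way as with the other optimisers. The sketch below is illustrative only; the model, data, and loss are made up, and it assumes the v0.9-era `train!(loss, params, data, opt)` API exercised elsewhere in this changeset.

```julia
using Flux
using Flux: train!, mse, params

m = Dense(10, 2)                               # toy model, illustrative only
data = [(randn(10), randn(2)) for _ in 1:100]  # fabricated dataset
loss(x, y) = mse(m(x), y)

opt = RADAM()                  # defaults: η = 0.001, β = (0.9, 0.999)
train!(loss, params(m), data, opt)
```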


@@ -2,7 +2,7 @@ module Optimise
 export train!,
   SGD, Descent, ADAM, Momentum, Nesterov, RMSProp,
-  ADAGrad, AdaMax, ADADelta, AMSGrad, NADAM, ADAMW,
+  ADAGrad, AdaMax, ADADelta, AMSGrad, NADAM, ADAMW, RADAM,
   InvDecay, ExpDecay, WeightDecay, stop, Optimiser

 include("optimisers.jl")


@@ -108,6 +108,36 @@ function apply!(o::ADAM, x, Δ)
   return Δ
 end

+"""
+    RADAM(η = 0.001, β = (0.9, 0.999))
+
+[RADAM](https://arxiv.org/pdf/1908.03265v1.pdf) optimiser (Rectified ADAM).
+"""
+mutable struct RADAM
+  eta::Float64
+  beta::Tuple{Float64,Float64}
+  state::IdDict
+end
+
+RADAM(η = 0.001, β = (0.9, 0.999)) = RADAM(η, β, IdDict())
+
+function apply!(o::RADAM, x, Δ)
+  η, β = o.eta, o.beta
+  ρ∞ = 2/(1-β[2])-1
+  mt, vt, βp, t = get!(o.state, x, (zero(x), zero(x), β, 1))
+  @. mt = β[1] * mt + (1 - β[1]) * Δ
+  @. vt = β[2] * vt + (1 - β[2]) * Δ^2
+  ρ = ρ∞ - 2t*βp[2]/(1-βp[2])
+  if ρ > 4
+    r = sqrt((ρ-4)*(ρ-2)*ρ∞/((ρ∞-4)*(ρ∞-2)*ρ))
+    @. Δ = mt / (1 - βp[1]) / (√(vt / (1 - βp[2])) + ϵ) * η * r
+  else
+    @. Δ = mt / (1 - βp[1]) * η
+  end
+  o.state[x] = (mt, vt, βp .* β, t+1)
+  return Δ
+end
+
 """
     AdaMax(params, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08)
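A note on the branch in `apply!` above: the adaptive (Adam-style) step is only taken once the approximated simple-moving-average length ρ exceeds 4; earlier steps fall back to a plain momentum update, which is the "rectification". A standalone sketch (not part of the commit) tracing this for the default β₂ = 0.999:

```julia
β2 = 0.999
ρ∞ = 2/(1 - β2) - 1          # maximum SMA length, = 1999 here
for t in (1, 2, 3, 4, 5, 10, 100)
    β2t = β2^t               # what βp[2] holds at step t
    ρ = ρ∞ - 2t*β2t/(1 - β2t)
    if ρ > 4
        r = sqrt((ρ-4)*(ρ-2)*ρ∞ / ((ρ∞-4)*(ρ∞-2)*ρ))
        println("t=$t: ρ≈$(round(ρ, digits=2)), adaptive step scaled by r≈$(round(r, digits=3))")
    else
        println("t=$t: ρ≈$(round(ρ, digits=2)) ≤ 4, plain momentum step")
    end
end
```

With these defaults the adaptive branch first engages at t = 5, so the very first updates behave like SGD with momentum while the second-moment estimate warms up.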


@@ -5,7 +5,7 @@ using Test
 @testset "Optimise" begin
   w = randn(10, 10)
   @testset for opt in [ADAMW(), ADAGrad(0.1), AdaMax(), ADADelta(0.9), AMSGrad(),
-                       NADAM(), Descent(0.1), ADAM(), Nesterov(), RMSProp(),
+                       NADAM(), RADAM(), Descent(0.1), ADAM(), Nesterov(), RMSProp(),
                        Momentum()]
     w′ = param(randn(10, 10))
     loss(x) = Flux.mse(w*x, w′*x)
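The hunk is truncated here, so the body of the testset loop is not shown. Purely as an illustration (hypothetical, not the actual test code), one way to drive `apply!` directly under the Tracker-era `param`/`back!` API visible above, fitting the tracked `w′` toward the fixed `w`:

```julia
for t = 1:10^5
  l = loss(rand(10))                                   # tracked scalar loss
  Flux.back!(l)                                        # backprop into w′.grad
  w′.data .-= Flux.Optimise.apply!(opt, w′.data, w′.grad)
  w′.grad .= 0                                         # reset accumulated gradient
end
@test loss(rand(10)) < 0.01                            # optimiser should converge
```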