From ebbad0d135a996dc807909201dccc74493936262 Mon Sep 17 00:00:00 2001 From: Fredrik Bagge Carlson Date: Mon, 19 Aug 2019 12:22:32 +0800 Subject: [PATCH] Add RADAM optimizer --- src/optimise/optimisers.jl | 30 ++++++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/src/optimise/optimisers.jl b/src/optimise/optimisers.jl index 939a4678..a3f4cdbd 100644 --- a/src/optimise/optimisers.jl +++ b/src/optimise/optimisers.jl @@ -108,6 +108,36 @@ function apply!(o::ADAM, x, Δ) return Δ end +""" + RADAM(η = 0.001, β = (0.9, 0.999)) + +[RADAM](https://arxiv.org/pdf/1908.03265v1.pdf) optimiser (Rectified ADAM). +""" +mutable struct RADAM + eta::Float64 + beta::Tuple{Float64,Float64} + state::IdDict +end + +RADAM(η = 0.001, β = (0.9, 0.999)) = RADAM(η, β, IdDict()) + +function apply!(o::RADAM, x, Δ) + η, β = o.eta, o.beta + ρ∞ = 2/(1-β[2])-1 + mt, vt, βp, t = get!(o.state, x, (zero(x), zero(x), β, 1)) + @. mt = β[1] * mt + (1 - β[1]) * Δ + @. vt = β[2] * vt + (1 - β[2]) * Δ^2 + ρ = ρ∞ - 2t*βp[2]/(1-βp[2]) + if ρ > 4 + r = sqrt((ρ-4)*(ρ-2)*ρ∞/((ρ∞-4)*(ρ∞-2)*ρ)) + @. Δ = mt / (1 - βp[1]) / (√(vt / (1 - βp[2])) + ϵ) * η * r + else + @. Δ = mt / (1 - βp[1]) * η + end + o.state[x] = (mt, vt, βp .* β, t+1) + return Δ +end + """ AdaMax(params, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08)