From ea9b5471faaf62d0564f151e597354d2aa6acc78 Mon Sep 17 00:00:00 2001
From: tejank10
Date: Tue, 3 Apr 2018 01:27:22 +0530
Subject: [PATCH 1/3] NADAM optimizer

---
 src/Flux.jl                |  2 +-
 src/optimise/Optimise.jl   |  2 +-
 src/optimise/interface.jl  |  9 +++++++++
 src/optimise/optimisers.jl | 13 +++++++++++++
 test/optimise.jl           |  2 +-
 5 files changed, 25 insertions(+), 3 deletions(-)

diff --git a/src/Flux.jl b/src/Flux.jl
index 6288dba7..05182f3f 100644
--- a/src/Flux.jl
+++ b/src/Flux.jl
@@ -9,7 +9,7 @@ using MacroTools: @forward
 
 export Chain, Dense, RNN, LSTM, GRU, Conv, Conv2D,
   Dropout, LayerNorm, BatchNorm,
-  SGD, ADAM, Momentum, Nesterov, AMSGrad,
+  SGD, ADAM, Momentum, Nesterov, AMSGrad, NADAM,
   param, params, mapleaves, cpu, gpu
 
 @reexport using NNlib
diff --git a/src/optimise/Optimise.jl b/src/optimise/Optimise.jl
index acec542e..7ac466fb 100644
--- a/src/optimise/Optimise.jl
+++ b/src/optimise/Optimise.jl
@@ -1,7 +1,7 @@
 module Optimise
 
 export update!, params, train!,
-  SGD, ADAM, Momentum, Nesterov, RMSProp, ADAGrad, ADADelta, AMSGrad
+  SGD, ADAM, Momentum, Nesterov, RMSProp, ADAGrad, ADADelta, AMSGrad, NADAM
 
 struct Param{T}
   x::T
diff --git a/src/optimise/interface.jl b/src/optimise/interface.jl
index 42b05dc8..505ca92b 100644
--- a/src/optimise/interface.jl
+++ b/src/optimise/interface.jl
@@ -82,3 +82,12 @@ tuning.
 """
 AMSGrad(ps, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0) =
   optimiser(ps, p -> amsgrad(p; η = η, β1 = β1, β2 = β2, ϵ = ϵ), p -> invdecay(p, decay), p -> descent(p, 1))
+
+"""
+    NADAM(params, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0)
+
+[NADAM](https://openreview.net/pdf?id=OM0jvwB8jIp57ZJjtNEZ) optimiser. Parameters other
+than learning rate don't need tuning.
+"""
+NADAM(ps, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0) =
+  optimiser(ps, p->nadam(p; η=η, β1=β1, β2=β2, ϵ=ϵ), p->invdecay(p,decay), p->descent(p,1))
diff --git a/src/optimise/optimisers.jl b/src/optimise/optimisers.jl
index c09e6131..bc1f9805 100644
--- a/src/optimise/optimisers.jl
+++ b/src/optimise/optimisers.jl
@@ -74,6 +74,19 @@ function amsgrad(p::Param; η::Real = 0.001, β1::Real = 0.9, β2::Real = 0.999,
   end
 end
 
+function nadam(p::Param; η::Real = 0.001, β1::Real = 0.9, β2::Real = 0.999, ϵ::Real = 1e-8)
+  mt = zeros(p.x)
+  vt = zeros(p.x)
+  β1p, β2p = β1, β2
+  function ()
+    @. mt = β1 * mt + (1 - β1) * p.Δ
+    @. vt = β2 * vt + (1 - β2) * p.Δ^2
+    @. p.Δ = (β1 * mt + (1 - β1) * p.Δ) / (1 - β1p) / (√(vt / (1 - β2p)) + ϵ) * η
+    β1p *= β1
+    β2p *= β2
+  end
+end
+
 clip(p::Param, thresh::Real) = () -> clamp!(p.Δ, -thresh, thresh)
 
 function expdecay(p::Param, γ::Real)
diff --git a/test/optimise.jl b/test/optimise.jl
index d57e4985..128d6231 100644
--- a/test/optimise.jl
+++ b/test/optimise.jl
@@ -3,7 +3,7 @@ using Flux.Tracker
 
 @testset "Optimise" begin
   w = randn(10, 10)
-  @testset for Opt in [SGD, Nesterov, Momentum, ADAM, RMSProp, ps -> ADAGrad(ps, 0.1), ADADelta, AMSGrad]
+  @testset for Opt in [SGD, Nesterov, Momentum, ADAM, RMSProp, ps -> ADAGrad(ps, 0.1), ADADelta, AMSGrad, NADAM]
     w′ = param(randn(10, 10))
     loss(x) = Flux.mse(w*x, w′*x)
     opt = Opt([w′])

From 3ead66298736a29ab6c11b0c4fdb8c996a903471 Mon Sep 17 00:00:00 2001
From: tejank10
Date: Wed, 4 Apr 2018 15:18:44 +0530
Subject: [PATCH 2/3] Update rule fixed

---
 src/optimise/optimisers.jl | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/optimise/optimisers.jl b/src/optimise/optimisers.jl
index bc1f9805..31d47c32 100644
--- a/src/optimise/optimisers.jl
+++ b/src/optimise/optimisers.jl
@@ -81,7 +81,7 @@ function nadam(p::Param; η::Real = 0.001, β1::Real = 0.9, β2::Real = 0.999,
   function ()
     @. mt = β1 * mt + (1 - β1) * p.Δ
     @. vt = β2 * vt + (1 - β2) * p.Δ^2
-    @. p.Δ = (β1 * mt + (1 - β1) * p.Δ) / (1 - β1p) / (√(vt / (1 - β2p)) + ϵ) * η
+    @. p.Δ = (β1 * mt / (1 - β1 * β1p) + (1 - β1) * p.Δ / (1 - β1p)) / √(vt * β2 / (1 - β2p) + ϵ) * η
     β1p *= β1
     β2p *= β2
   end

From 65847bb745e39f2e2a3fa5b50449f1c305b6a411 Mon Sep 17 00:00:00 2001
From: tejank10
Date: Wed, 4 Apr 2018 15:25:20 +0530
Subject: [PATCH 3/3] moved epsilon into sqrt

---
 src/optimise/optimisers.jl | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/src/optimise/optimisers.jl b/src/optimise/optimisers.jl
index 31d47c32..01a92a70 100644
--- a/src/optimise/optimisers.jl
+++ b/src/optimise/optimisers.jl
@@ -27,7 +27,7 @@ function rmsprop(p::Param; η::Real = 0.001, ρ::Real = 0.9, ϵ::Real = 1e-8)
   acc = zeros(p.x)
   function ()
     @. acc = ρ * acc + (1 - ρ) * p.Δ^2
-    @. p.Δ *= η / (√acc + ϵ)
+    @. p.Δ *= η / √(acc + ϵ)
   end
 end
 
@@ -35,7 +35,7 @@ function adagrad(p::Param; η::Real = 0.01, ϵ::Real = 1e-8)
   acc = zeros(p.x) .+ ϵ
   function ()
     @. acc += p.Δ^2
-    @. p.Δ *= η / √acc
+    @. p.Δ *= η / √(acc + ϵ)
   end
 end
 
@@ -56,7 +56,7 @@ function adam(p::Param; η::Real = 0.001, β1::Real = 0.9, β2::Real = 0.999, ϵ
   function ()
     @. mt = β1 * mt + (1 - β1) * p.Δ
     @. vt = β2 * vt + (1 - β2) * p.Δ^2
-    @. p.Δ = mt / (1 - β1p) / (√(vt / (1 - β2p)) + ϵ) * η
+    @. p.Δ = mt / (1 - β1p) / √(vt / (1 - β2p) + ϵ) * η
     β1p *= β1
     β2p *= β2
   end
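
Note (not part of the patches above): the corrected update rule from PATCH 2/3 is easier to read outside the @. broadcast. Below is a minimal scalar sketch of a single NADAM step under the same hyperparameter names as optimisers.jl; nadam_step is a hypothetical helper for illustration, not a function added by this series. In the closure, β1p = β1^t and β2p = β2^t at step t, which is where the bias-correction factors 1 - β1 * β1p and 1 - β1p come from.

# Scalar sketch of one NADAM step, mirroring the rule introduced in PATCH 2/3.
function nadam_step(g, mt, vt, β1p, β2p; η = 0.001, β1 = 0.9, β2 = 0.999, ϵ = 1e-8)
  mt = β1 * mt + (1 - β1) * g       # first-moment (momentum) estimate
  vt = β2 * vt + (1 - β2) * g^2     # second-moment estimate
  Δ = (β1 * mt / (1 - β1 * β1p) +   # bias-corrected momentum with Nesterov-style look-ahead
       (1 - β1) * g / (1 - β1p)) /  # bias-corrected current gradient
      √(vt * β2 / (1 - β2p) + ϵ) * η
  return Δ, mt, vt, β1p * β1, β2p * β2  # caller subtracts Δ from the parameter and keeps the new state
end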
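
For completeness, a usage sketch in the style of the updated test/optimise.jl: the 10×10 toy regression below mirrors the test's setup, while the training loop after opt = NADAM([w′]) is not shown in the hunk, so its length and body are assumptions about how the test exercises the optimiser.

using Flux, Flux.Tracker

w  = randn(10, 10)            # fixed "target" weights
w′ = param(randn(10, 10))     # trainable (tracked) weights
loss(x) = Flux.mse(w*x, w′*x)

opt = NADAM([w′])             # same call signature as ADAM and AMSGrad in interface.jl
for t = 1:10_000
  l = loss(rand(10))          # forward pass on a random input
  back!(l)                    # accumulate gradients into w′ (the p.Δ seen by nadam)
  opt()                       # apply one NADAM update to w′
end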