NADAM optimizer

This commit is contained in:
tejank10 2018-04-03 01:27:22 +05:30
parent 4320738d87
commit ea9b5471fa
5 changed files with 25 additions and 3 deletions

View File

@ -9,7 +9,7 @@ using MacroTools: @forward
export Chain, Dense, RNN, LSTM, GRU, Conv, Conv2D, export Chain, Dense, RNN, LSTM, GRU, Conv, Conv2D,
Dropout, LayerNorm, BatchNorm, Dropout, LayerNorm, BatchNorm,
SGD, ADAM, Momentum, Nesterov, AMSGrad, SGD, ADAM, Momentum, Nesterov, AMSGrad, NADAM,
param, params, mapleaves, cpu, gpu param, params, mapleaves, cpu, gpu
@reexport using NNlib @reexport using NNlib

View File

@ -1,7 +1,7 @@
module Optimise module Optimise
export update!, params, train!, export update!, params, train!,
SGD, ADAM, Momentum, Nesterov, RMSProp, ADAGrad, ADADelta, AMSGrad SGD, ADAM, Momentum, Nesterov, RMSProp, ADAGrad, ADADelta, AMSGrad, NADAM
struct Param{T} struct Param{T}
x::T x::T

View File

@ -82,3 +82,12 @@ tuning.
""" """
AMSGrad(ps, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0) = AMSGrad(ps, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0) =
optimiser(ps, p -> amsgrad(p; η = η, β1 = β1, β2 = β2, ϵ = ϵ), p -> invdecay(p, decay), p -> descent(p, 1)) optimiser(ps, p -> amsgrad(p; η = η, β1 = β1, β2 = β2, ϵ = ϵ), p -> invdecay(p, decay), p -> descent(p, 1))
"""
NADAM(params, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0)
[NADAM](https://openreview.net/pdf?id=OM0jvwB8jIp57ZJjtNEZ) optimiser. Parameters other
than learning rate don't need tuning.
"""
function NADAM(ps, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0)
  # Stage 1: per-parameter NADAM rule built from the supplied hyperparameters.
  rule = q -> nadam(q; η = η, β1 = β1, β2 = β2, ϵ = ϵ)
  # Stage 2: `invdecay` driven by the `decay` keyword.
  schedule = q -> invdecay(q, decay)
  # Stage 3: `descent` with a fixed factor of 1 (η is already passed to `nadam`).
  apply = q -> descent(q, 1)
  # The three stages are handed to `optimiser` in this order for the parameters `ps`.
  optimiser(ps, rule, schedule, apply)
end

View File

@ -74,6 +74,19 @@ function amsgrad(p::Param; η::Real = 0.001, β1::Real = 0.9, β2::Real = 0.999,
end end
end end
"""
    nadam(p::Param; η = 0.001, β1 = 0.9, β2 = 0.999, ϵ = 1e-8)

Build the NADAM (Adam with Nesterov momentum) update step for the parameter `p`.

Returns a zero-argument closure. Each call reads the gradient stored in `p.Δ`,
updates the running first- and second-moment estimates, and overwrites `p.Δ` in
place with the scaled step

    η * (β1 * m + (1 - β1) * g) / (1 - β1^t) / (sqrt(v / (1 - β2^t)) + ϵ)

where `g` is the incoming gradient, `m` and `v` are the moment estimates and `t`
is the number of calls made so far (starting at 1).

# Keywords
- `η`: learning rate multiplying the step.
- `β1`, `β2`: decay rates of the first- and second-moment estimates.
- `ϵ`: constant added to the denominator to avoid division by zero.
"""
function nadam(p::Param; η::Real = 0.001, β1::Real = 0.9, β2::Real = 0.999, ϵ::Real = 1e-8)
  # Moment accumulators, zero-initialised with the same shape as the parameter.
  mt = zeros(p.x)
  vt = zeros(p.x)
  # Running powers β1^t and β2^t used for bias correction; t = 1 on the first call.
  β1p, β2p = β1, β2
  function ()
    @. mt = β1 * mt + (1 - β1) * p.Δ
    @. vt = β2 * vt + (1 - β2) * p.Δ^2
    # Nesterov look-ahead numerator, divided by the square root of the
    # bias-corrected second moment (not the raw second moment), with ϵ added
    # outside the root.
    @. p.Δ = (β1 * mt + (1 - β1) * p.Δ) / (1 - β1p) / (sqrt(vt / (1 - β2p)) + ϵ) * η
    β1p *= β1
    β2p *= β2
  end
end
clip(p::Param, thresh::Real) = () -> clamp!(p.Δ, -thresh, thresh) clip(p::Param, thresh::Real) = () -> clamp!(p.Δ, -thresh, thresh)
function expdecay(p::Param, γ::Real) function expdecay(p::Param, γ::Real)

View File

@ -3,7 +3,7 @@ using Flux.Tracker
@testset "Optimise" begin @testset "Optimise" begin
w = randn(10, 10) w = randn(10, 10)
@testset for Opt in [SGD, Nesterov, Momentum, ADAM, RMSProp, ps -> ADAGrad(ps, 0.1), ADADelta, AMSGrad] @testset for Opt in [SGD, Nesterov, Momentum, ADAM, RMSProp, ps -> ADAGrad(ps, 0.1), ADADelta, AMSGrad, NADAM]
w = param(randn(10, 10)) w = param(randn(10, 10))
loss(x) = Flux.mse(w*x, w*x) loss(x) = Flux.mse(w*x, w*x)
opt = Opt([w]) opt = Opt([w])