NADAM optimizer
parent 4320738d87 · commit ea9b5471fa
@@ -9,7 +9,7 @@ using MacroTools: @forward

 export Chain, Dense, RNN, LSTM, GRU, Conv, Conv2D,
   Dropout, LayerNorm, BatchNorm,
-  SGD, ADAM, Momentum, Nesterov, AMSGrad,
+  SGD, ADAM, Momentum, Nesterov, AMSGrad, NADAM,
   param, params, mapleaves, cpu, gpu

 @reexport using NNlib
@@ -1,7 +1,7 @@
 module Optimise

 export update!, params, train!,
-  SGD, ADAM, Momentum, Nesterov, RMSProp, ADAGrad, ADADelta, AMSGrad
+  SGD, ADAM, Momentum, Nesterov, RMSProp, ADAGrad, ADADelta, AMSGrad, NADAM

 struct Param{T}
   x::T
@@ -82,3 +82,12 @@ tuning.
 """
 AMSGrad(ps, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0) =
   optimiser(ps, p -> amsgrad(p; η = η, β1 = β1, β2 = β2, ϵ = ϵ), p -> invdecay(p, decay), p -> descent(p, 1))
+
+"""
+    NADAM(params, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0)
+
+[NADAM](https://openreview.net/pdf?id=OM0jvwB8jIp57ZJjtNEZ) optimiser. Parameters other
+than learning rate don't need tuning.
+"""
+NADAM(ps, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0) =
+  optimiser(ps, p -> nadam(p; η = η, β1 = β1, β2 = β2, ϵ = ϵ), p -> invdecay(p, decay), p -> descent(p, 1))
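For context, this constructor follows the same pattern as AMSGrad above: it wires the `nadam` update together with inverse decay and a plain descent step. A minimal usage sketch under the Flux API of this era (the model, data, and loss here are made up for illustration):

using Flux

m = Dense(10, 2)                      # toy model
loss(x, y) = Flux.mse(m(x), y)
data = [(rand(10), rand(2))]
opt = NADAM(params(m))                # defaults: η = 0.001, β1 = 0.9, β2 = 0.999
Flux.train!(loss, data, opt)          # one pass over `data`, applying NADAM updates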
@@ -74,6 +74,19 @@ function amsgrad(p::Param; η::Real = 0.001, β1::Real = 0.9, β2::Real = 0.999,
   end
 end

+function nadam(p::Param; η::Real = 0.001, β1::Real = 0.9, β2::Real = 0.999, ϵ::Real = 1e-8)
+  mt = zeros(p.x)
+  vt = zeros(p.x)
+  β1p, β2p = β1, β2
+  function ()
+    @. mt = β1 * mt + (1 - β1) * p.Δ
+    @. vt = β2 * vt + (1 - β2) * p.Δ^2
+    @. p.Δ = (β1 * mt + (1 - β1) * p.Δ) / (1 - β1p) / (√(vt / (1 - β2p)) + ϵ) * η
+    β1p *= β1
+    β2p *= β2
+  end
+end
+
 clip(p::Param, thresh::Real) = () -> clamp!(p.Δ, -thresh, thresh)

 function expdecay(p::Param, γ::Real)
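The `nadam` closure mirrors `adam` but applies a Nesterov-style correction: the step uses `β1 * mt + (1 - β1) * p.Δ`, blending the freshly updated momentum with the current gradient instead of the momentum alone, while `β1p` and `β2p` track running powers of the decay rates for bias correction. A minimal standalone sketch of the same update rule, minimising `sum(x.^2)` on a plain vector (names here are illustrative, not Flux API):

# NADAM update rule from the diff above, applied outside Flux.
function nadam_demo(x; η = 0.001, β1 = 0.9, β2 = 0.999, ϵ = 1e-8, steps = 1000)
  mt, vt = zero(x), zero(x)   # first and second moment estimates
  β1p, β2p = β1, β2           # running powers of β1, β2 for bias correction
  for _ in 1:steps
    g = 2 .* x                # gradient of sum(x.^2)
    @. mt = β1 * mt + (1 - β1) * g
    @. vt = β2 * vt + (1 - β2) * g^2
    # Nesterov lookahead: blend the updated momentum with the raw gradient
    @. x -= η * (β1 * mt + (1 - β1) * g) / (1 - β1p) / (√(vt / (1 - β2p)) + ϵ)
    β1p *= β1
    β2p *= β2
  end
  return x
end

nadam_demo(randn(5))  # each entry converges toward 0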
@@ -3,7 +3,7 @@ using Flux.Tracker

 @testset "Optimise" begin
   w = randn(10, 10)
-  @testset for Opt in [SGD, Nesterov, Momentum, ADAM, RMSProp, ps -> ADAGrad(ps, 0.1), ADADelta, AMSGrad]
+  @testset for Opt in [SGD, Nesterov, Momentum, ADAM, RMSProp, ps -> ADAGrad(ps, 0.1), ADADelta, AMSGrad, NADAM]
     w′ = param(randn(10, 10))
     loss(x) = Flux.mse(w*x, w′*x)
     opt = Opt([w′])
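The test body is truncated by the diff context; a hypothetical continuation, assuming the era's API where `Opt([w′])` returns a callable that applies one update per invocation (a sketch, not the actual test code):

for t = 1:10^5
  l = loss(rand(10))   # forward pass through the tracked parameter w′
  back!(l)             # accumulate gradients (Flux.Tracker is imported above)
  opt()                # apply one optimiser step, e.g. the NADAM update
end
@test Flux.mse(w, w′) < 0.01   # w′ should have converged toward w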