diff --git a/src/Flux.jl b/src/Flux.jl
index 7125630f..7d1d66e6 100644
--- a/src/Flux.jl
+++ b/src/Flux.jl
@@ -23,7 +23,7 @@ include("optimise/Optimise.jl")
 using .Optimise
 using .Optimise: @epochs
 export SGD, ADAM, AdaMax, Momentum, Nesterov,
-  RMSProp, ADAGrad, ADADelta, AMSGrad
+  RMSProp, ADAGrad, ADADelta, AMSGrad, NADAM
 
 include("utils.jl")
 include("onehot.jl")
diff --git a/src/optimise/Optimise.jl b/src/optimise/Optimise.jl
index 5d5d9ea0..0c541b93 100644
--- a/src/optimise/Optimise.jl
+++ b/src/optimise/Optimise.jl
@@ -1,7 +1,8 @@
 module Optimise
 
 export train!,
-  SGD, ADAM, AdaMax, Momentum, Nesterov, RMSProp, ADAGrad, ADADelta, AMSGrad
+  SGD, ADAM, AdaMax, Momentum, Nesterov,
+  RMSProp, ADAGrad, ADADelta, AMSGrad, NADAM
 
 struct Param{T}
   x::T
diff --git a/src/optimise/interface.jl b/src/optimise/interface.jl
index 29068983..3a07f6ce 100644
--- a/src/optimise/interface.jl
+++ b/src/optimise/interface.jl
@@ -91,3 +91,12 @@ tuning.
 """
 AMSGrad(ps, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0) =
   optimiser(ps, p -> amsgrad(p; η = η, β1 = β1, β2 = β2, ϵ = ϵ), p -> invdecay(p, decay), p -> descent(p, 1))
+
+"""
+    NADAM(params, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0)
+
+[NADAM](https://openreview.net/pdf?id=OM0jvwB8jIp57ZJjtNEZ) optimiser. Parameters other
+than learning rate don't need tuning.
+"""
+NADAM(ps, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0) =
+  optimiser(ps, p->nadam(p; η=η, β1=β1, β2=β2, ϵ=ϵ), p->invdecay(p,decay), p->descent(p,1))
diff --git a/src/optimise/optimisers.jl b/src/optimise/optimisers.jl
index 29b058ba..e3a4ed34 100644
--- a/src/optimise/optimisers.jl
+++ b/src/optimise/optimisers.jl
@@ -27,7 +27,7 @@ function rmsprop(p::Param; η::Real = 0.001, ρ::Real = 0.9, ϵ::Real = 1e-8)
   acc = zeros(p.x)
   function ()
     @. acc = ρ * acc + (1 - ρ) * p.Δ^2
-    @. p.Δ *= η / (√acc + ϵ)
+    @. p.Δ *= η / √(acc + ϵ)
   end
 end
 
@@ -35,7 +35,7 @@ function adagrad(p::Param; η::Real = 0.01, ϵ::Real = 1e-8)
   acc = zeros(p.x) .+ ϵ
   function ()
     @. acc += p.Δ^2
-    @. p.Δ *= η / √acc
+    @. p.Δ *= η / √(acc + ϵ)
   end
 end
 
@@ -56,7 +56,7 @@ function adam(p::Param; η::Real = 0.001, β1::Real = 0.9, β2::Real = 0.999, ϵ
   function ()
     @. mt = β1 * mt + (1 - β1) * p.Δ
     @. vt = β2 * vt + (1 - β2) * p.Δ^2
-    @. p.Δ = mt / (1 - β1p) / (√(vt / (1 - β2p)) + ϵ) * η
+    @. p.Δ = mt / (1 - β1p) / √(vt / (1 - β2p) + ϵ) * η
     β1p *= β1
     β2p *= β2
   end
@@ -86,6 +86,19 @@ function amsgrad(p::Param; η::Real = 0.001, β1::Real = 0.9, β2::Real = 0.999,
   end
 end
 
+function nadam(p::Param; η::Real = 0.001, β1::Real = 0.9, β2::Real = 0.999, ϵ::Real = 1e-8)
+  mt = zeros(p.x)
+  vt = zeros(p.x)
+  β1p, β2p = β1, β2
+  function ()
+    @. mt = β1 * mt + (1 - β1) * p.Δ
+    @. vt = β2 * vt + (1 - β2) * p.Δ^2
+    @. p.Δ = (β1 * mt / (1 - β1 * β1p) + (1 - β1) * p.Δ / (1 - β1p)) / √(vt * β2 / (1 - β2p) + ϵ) * η
+    β1p *= β1
+    β2p *= β2
+  end
+end
+
 clip(p::Param, thresh::Real) = () -> clamp!(p.Δ, -thresh, thresh)
 
 function expdecay(p::Param, γ::Real)
diff --git a/test/optimise.jl b/test/optimise.jl
index ae7ec8fe..c896bb39 100644
--- a/test/optimise.jl
+++ b/test/optimise.jl
@@ -3,7 +3,7 @@ using Flux.Tracker
 
 @testset "Optimise" begin
   w = randn(10, 10)
-  @testset for Opt in [SGD, Nesterov, Momentum, ADAM, AdaMax, RMSProp, ps -> ADAGrad(ps, 0.1), ADADelta, AMSGrad]
+  @testset for Opt in [SGD, Nesterov, Momentum, ADAM, AdaMax, RMSProp, ps -> ADAGrad(ps, 0.1), ADADelta, AMSGrad, NADAM]
     w′ = param(randn(10, 10))
     loss(x) = Flux.mse(w*x, w′*x)
     opt = Opt([w′])
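
For reference, a minimal sketch of how the new `NADAM` optimiser is exercised end to end, mirroring the updated loop in `test/optimise.jl` (Tracker-era API: `param`, `back!`, and calling the optimiser closure directly). The loop body beyond the hunk shown above is assumed from that API rather than taken from this diff.

```julia
# Sketch: fit trainable weights w′ toward fixed weights w with the new NADAM optimiser,
# following the same pattern as the test in test/optimise.jl.
using Flux, Flux.Tracker

w  = randn(10, 10)           # fixed target weights
w′ = param(randn(10, 10))    # trainable (tracked) weights
loss(x) = Flux.mse(w * x, w′ * x)

opt = NADAM([w′])            # defaults: η = 0.001, β1 = 0.9, β2 = 0.999, ϵ = 1e-8

for t = 1:10^5
  l = loss(rand(10))
  back!(l)                   # accumulate gradients into w′ via Tracker
  opt()                      # apply the NADAM update to w′
end

@show Flux.mse(w, w′)        # should be close to zero after training
```

As the `interface.jl` hunk shows, `NADAM(ps, η; ...)` composes `nadam` with `invdecay` and `descent(p, 1)` just like the other optimisers, so it also plugs into `train!` unchanged.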