Merge pull request #245 from freeboson/adamax

Add AdaMax optimizer
This commit is contained in:
Mike J Innes 2018-05-01 11:28:07 +01:00 committed by GitHub
commit ee89a7797e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 23 additions and 2 deletions

View File

@ -1,7 +1,7 @@
module Optimise module Optimise
export update!, params, train!, export update!, params, train!,
SGD, ADAM, Momentum, Nesterov, RMSProp, ADAGrad, ADADelta, AMSGrad SGD, ADAM, AdaMax, Momentum, Nesterov, RMSProp, ADAGrad, ADADelta, AMSGrad
struct Param{T} struct Param{T}
x::T x::T

View File

@ -56,6 +56,15 @@ RMSProp(ps, η = 0.001; ρ = 0.9, ϵ = 1e-8, decay = 0) =
ADAM(ps, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0) = ADAM(ps, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0) =
optimiser(ps, p->adam(p; η=η, β1=β1, β2=β2, ϵ=ϵ), p->invdecay(p,decay), p->descent(p,1)) optimiser(ps, p->adam(p; η=η, β1=β1, β2=β2, ϵ=ϵ), p->invdecay(p,decay), p->descent(p,1))
"""
AdaMax(params, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0)
[AdaMax](https://arxiv.org/abs/1412.6980v9) optimiser. Variant of ADAM based on
the -norm.
"""
AdaMax(ps, η = 0.002; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0) =
optimiser(ps, p->adamax(p; η=η, β1=β1, β2=β2, ϵ=ϵ), p->invdecay(p,decay), p->descent(p,1))
""" """
ADAGrad(params, η = 0.01; ϵ = 1e-8, decay = 0) ADAGrad(params, η = 0.01; ϵ = 1e-8, decay = 0)

View File

@ -62,6 +62,18 @@ function adam(p::Param; η::Real = 0.001, β1::Real = 0.9, β2::Real = 0.999, ϵ
end end
end end
function adamax(p::Param; η::Real = 0.002, β1::Real = 0.9, β2::Real = 0.999, ϵ::Real = 1e-8)
mt = zeros(p.x)
ut = zeros(p.x)
β1p = β1
function ()
@. mt = β1 * mt + (1 - β1) * p.Δ
@. ut = max(β2 * ut, abs(p.Δ))
@. p.Δ = (η/(1 - β1p)) * mt/(ut + ϵ)
β1p *= β1
end
end
function amsgrad(p::Param; η::Real = 0.001, β1::Real = 0.9, β2::Real = 0.999, ϵ::Real = 1e-8) function amsgrad(p::Param; η::Real = 0.001, β1::Real = 0.9, β2::Real = 0.999, ϵ::Real = 1e-8)
mt = zeros(p.x) mt = zeros(p.x)
vt = zeros(p.x) .+ ϵ vt = zeros(p.x) .+ ϵ

View File

@ -3,7 +3,7 @@ using Flux.Tracker
@testset "Optimise" begin @testset "Optimise" begin
w = randn(10, 10) w = randn(10, 10)
@testset for Opt in [SGD, Nesterov, Momentum, ADAM, RMSProp, ps -> ADAGrad(ps, 0.1), ADADelta, AMSGrad] @testset for Opt in [SGD, Nesterov, Momentum, ADAM, AdaMax, RMSProp, ps -> ADAGrad(ps, 0.1), ADADelta, AMSGrad]
w = param(randn(10, 10)) w = param(randn(10, 10))
loss(x) = Flux.mse(w*x, w*x) loss(x) = Flux.mse(w*x, w*x)
opt = Opt([w]) opt = Opt([w])