Implement AMSGrad optimiser
This commit is contained in:
parent
9f5c4dd3e9
commit
36001d085a
@ -1,7 +1,7 @@
|
|||||||
module Optimise
|
module Optimise
|
||||||
|
|
||||||
export update!, params, train!,
|
export update!, params, train!,
|
||||||
SGD, ADAM, Momentum, Nesterov, RMSProp, ADAGrad, ADADelta
|
SGD, ADAM, Momentum, Nesterov, RMSProp, ADAGrad, ADADelta, AMSGrad
|
||||||
|
|
||||||
struct Param{T}
|
struct Param{T}
|
||||||
x::T
|
x::T
|
||||||
|
@ -71,3 +71,12 @@ tuning.
|
|||||||
"""
|
"""
|
||||||
ADADelta(ps; η = 0.01, ρ = 0.95, ϵ = 1e-8, decay = 0) =
|
ADADelta(ps; η = 0.01, ρ = 0.95, ϵ = 1e-8, decay = 0) =
|
||||||
optimiser(ps, p -> adadelta(p; ρ = ρ, ϵ = ϵ), p -> invdecay(p, decay), p -> descent(p, 1))
|
optimiser(ps, p -> adadelta(p; ρ = ρ, ϵ = ϵ), p -> invdecay(p, decay), p -> descent(p, 1))
|
||||||
|
|
||||||
|
"""
|
||||||
|
AMSGrad(params; η = 0.001, β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0)
|
||||||
|
|
||||||
|
[AMSGrad](https://openreview.net/forum?id=ryQu7f-RZ) optimiser. Parameters don't need
|
||||||
|
tuning.
|
||||||
|
"""
|
||||||
|
AMSGrad(params; η = 0.001, β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0) =
|
||||||
|
optimiser(ps, p -> amsgrad(p; η = η, β1 = β1, β2 = β2, ϵ = ϵ), p -> invdecay(p, decay), p -> descent(p, 1))
|
||||||
|
@ -67,8 +67,20 @@ function adam(p::Param; η::Real = 0.001, β1::Real = 0.9, β2::Real = 0.999, ϵ
|
|||||||
function ()
|
function ()
|
||||||
@. mt = β1 * mt + (1 - β1) * p.Δ
|
@. mt = β1 * mt + (1 - β1) * p.Δ
|
||||||
@. vt = β2 * vt + (1 - β2) * p.Δ ^ 2
|
@. vt = β2 * vt + (1 - β2) * p.Δ ^ 2
|
||||||
@. p.Δ = √(1 - β2p) / √(1 - β1p) * mt / √vt * η
|
@. p.Δ = √(1 - β2p) / (1 - β1p) * mt / √vt * η
|
||||||
β1p *= β1
|
β1p *= β1
|
||||||
β2p *= β2
|
β2p *= β2
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
|
function amsgrad(p::Param; η::Real = 0.001, β1::Real = 0.9, β2::Real = 0.999, ϵ::Real = 1e-8)
|
||||||
|
mt = zeros(p.x)
|
||||||
|
vt = zeros(p.x) .+ ϵ
|
||||||
|
v̂t = zeros(p.x) .+ ϵ
|
||||||
|
function ()
|
||||||
|
@. mt = β1 * mt + (1 - β1) * p.Δ
|
||||||
|
@. vt = β2 * vt + (1 - β2) * p.Δ ^ 2
|
||||||
|
@. v̂t = max.(v̂t, vt)
|
||||||
|
@. p.Δ = η * mt / √v̂t
|
||||||
|
end
|
||||||
|
end
|
||||||
|
Loading…
Reference in New Issue
Block a user