2018-05-31 19:29:59 +00:00
|
|
|
|
using Flux
|
|
|
|
|
using Base: @get!
|
2017-08-22 21:25:18 +00:00
|
|
|
|
|
2018-05-31 19:29:59 +00:00
|
|
|
|
# Small positive constant added to denominators (in the RMSProp and ADAM
# updates) so that a zero second-moment estimate never causes division by zero.
const ϵ = 1.0e-8
|
2018-07-03 10:11:32 +00:00
|
|
|
|
|
2018-05-31 19:29:59 +00:00
|
|
|
|
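# TODO: should use weak refs — the per-parameter state dicts (IdDict) hold strong references to their parameter keys, keeping those parameters alive.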
|
2017-09-01 21:06:51 +00:00
|
|
|
|
|
"""
    Descent(η)

Classic gradient descent optimiser with learning rate `η`.

For each parameter `p` and its gradient `δp`, the overall effect is
`p -= η*δp`; `update!` performs the scaling `δp .*= η` in place and does
not touch `p` itself.
"""
mutable struct Descent
    eta::Float64  # learning rate η
end
|
|
|
|
|
|
2018-05-31 19:29:59 +00:00
|
|
|
|
"""
    update!(o::Descent, x, Δ)

Scale the gradient `Δ` of parameter `x` by the learning rate, in place,
and return `Δ`. `x` is not read or modified.
"""
function update!(o::Descent, x, Δ)
    lr = o.eta
    @. Δ = Δ * lr
    return Δ
end
|
|
|
|
|
|
"""
    Momentum(η = 0.01, ρ = 0.9)

Gradient descent with learning rate `η` and momentum `ρ`.

Per-parameter velocity buffers are created on first use by `update!` and
kept in `velocity`, keyed by parameter identity.
"""
mutable struct Momentum
    eta::Float64      # learning rate η
    rho::Float64      # momentum coefficient ρ
    velocity::IdDict  # parameter => velocity buffer
end

# Each optimiser gets its own fresh state dictionary.
Momentum(η = 0.01, ρ = 0.9) = Momentum(η, ρ, IdDict())
|
2018-05-31 19:29:59 +00:00
|
|
|
|
|
|
|
|
|
"""
    update!(o::Momentum, x, Δ)

Turn the gradient `Δ` of parameter `x` into a momentum step, in place.

The velocity `v` stored for `x` (zero-initialised on first use) is updated as
`v = ρ*v - η*Δ`, and `Δ` is overwritten with `-v`. Returns `Δ`.
"""
function update!(o::Momentum, x, Δ)
    η, ρ = o.eta, o.rho
    # Build the zero buffer only when `x` has no state yet; passing `zero(x)`
    # as a plain default would allocate a full-size array on every call.
    v = get!(() -> zero(x), o.velocity, x)::typeof(x)
    @. v = ρ * v - η * Δ
    @. Δ = -v
    return Δ
end
|
2017-08-22 21:25:18 +00:00
|
|
|
|
|
"""
    Nesterov(η, ρ = 0.9)

Gradient descent with learning rate `η` and Nesterov momentum `ρ`.

Per-parameter velocity buffers are created on first use by `update!` and
kept in `velocity`, keyed by parameter identity.
"""
mutable struct Nesterov
    eta::Float64      # learning rate η
    rho::Float64      # Nesterov momentum coefficient ρ
    velocity::IdDict  # parameter => velocity buffer
end

function Nesterov(η, ρ = 0.9)
    state = IdDict()
    return Nesterov(η, ρ, state)
end
|
2018-05-31 19:29:59 +00:00
|
|
|
|
|
|
|
|
|
"""
    update!(o::Nesterov, x, Δ)

Turn the gradient `Δ` of parameter `x` into a Nesterov-momentum step, in place.

With the stored velocity `v` (zero-initialised on first use), the step is
`d = ρ^2*v - (1+ρ)*η*Δ`. The velocity is then advanced to `v = ρ*v - η*Δ`,
and `Δ` is overwritten with `-d`. Returns `Δ`.
"""
function update!(o::Nesterov, x, Δ)
    η, ρ = o.eta, o.rho
    # Lazily create the velocity buffer so existing state costs no allocation.
    v = get!(() -> zero(x), o.velocity, x)::typeof(x)
    # `d` must be computed from the *old* velocity, before `v` is advanced.
    d = @. ρ^2 * v - (1 + ρ) * η * Δ
    @. v = ρ * v - η * Δ
    @. Δ = -d
    return Δ
end
|
|
|
|
|
|
"""
    RMSProp(η = 0.001, ρ = 0.9)

[RMSProp](http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf)
optimiser. Parameters other than learning rate don't need tuning. Often a good
choice for recurrent networks.

`acc` holds the running average of squared gradients for each parameter,
keyed by parameter identity and created on first use by `update!`.
"""
mutable struct RMSProp
    eta::Float64  # learning rate η
    rho::Float64  # decay rate ρ of the squared-gradient average
    acc::IdDict   # parameter => running average of Δ^2
end

function RMSProp(η = 0.001, ρ = 0.9)
    state = IdDict()
    return RMSProp(η, ρ, state)
end
|
2018-05-31 19:29:59 +00:00
|
|
|
|
|
|
|
|
|
"""
    update!(o::RMSProp, x, Δ)

Rescale the gradient `Δ` of parameter `x` by RMSProp, in place.

The running average `acc` for `x` (zero-initialised on first use) is updated as
`acc = ρ*acc + (1-ρ)*Δ^2`, then `Δ` is multiplied by `η / (√acc + ϵ)`.
Returns `Δ`.
"""
function update!(o::RMSProp, x, Δ)
    η, ρ = o.eta, o.rho
    # Lazily create the accumulator so existing state costs no allocation.
    acc = get!(() -> zero(x), o.acc, x)::typeof(x)
    @. acc = ρ * acc + (1 - ρ) * Δ^2
    @. Δ *= η / (√acc + ϵ)
    return Δ
end
|
|
|
|
|
|
"""
    ADAM(η = 0.001, β = (0.9, 0.999))

[ADAM](https://arxiv.org/abs/1412.6980v8) optimiser with learning rate `η`
and moment decay rates `β = (β1, β2)`.

The small constant added to the denominator is the module-level `ϵ`; it is
not a constructor argument. Per-parameter state (first moment, second moment,
and the running powers of `β` used for bias correction) is created on first
use by `update!` and kept in `state`, keyed by parameter identity.
"""
mutable struct ADAM
    eta::Float64                  # learning rate η
    beta::Tuple{Float64,Float64}  # (β1, β2) decay rates for the two moments
    state::IdDict                 # parameter => (mt, vt, βp)
end

ADAM(η = 0.001, β = (0.9, 0.999)) = ADAM(η, β, IdDict())
|
2018-05-31 19:29:59 +00:00
|
|
|
|
|
|
|
|
|
"""
    update!(o::ADAM, x, Δ)

Turn the gradient `Δ` of parameter `x` into an ADAM step, in place.

State for `x` is `(mt, vt, βp)`. It is initialised on first use to zero moments
and `βp = β`. The moments are updated in place:
`mt = β1*mt + (1-β1)*Δ` and `vt = β2*vt + (1-β2)*Δ^2`.
`Δ` is then overwritten with the bias-corrected step
`η * (mt/(1-βp1)) / (√(vt/(1-βp2)) + ϵ)`. Finally `βp` is advanced to `βp .* β`
for the next call. Returns `Δ`.
"""
function update!(o::ADAM, x, Δ)
    η, β = o.eta, o.beta
    # Lazily build the initial state; an eager default tuple would allocate
    # two full-size zero arrays on every call, even when state already exists.
    mt, vt, βp = get!(() -> (zero(x), zero(x), β), o.state, x)
    @. mt = β[1] * mt + (1 - β[1]) * Δ
    @. vt = β[2] * vt + (1 - β[2]) * Δ^2
    @. Δ = mt / (1 - βp[1]) / (√(vt / (1 - βp[2])) + ϵ) * η
    # βp is an immutable tuple, so the advanced powers must be stored back.
    o.state[x] = (mt, vt, βp .* β)
    return Δ
end
|
2018-05-31 19:29:59 +00:00
|
|
|
|
|
|
|
|
|
# """
|
|
|
|
|
# AdaMax(params, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0)
|
|
|
|
|
#
|
|
|
|
|
# [AdaMax](https://arxiv.org/abs/1412.6980v9) optimiser. Variant of ADAM based on
|
|
|
|
|
# the ∞-norm.
|
|
|
|
|
# """
|
|
|
|
|
|
|
|
|
|
# """
|
|
|
|
|
# ADAGrad(params, η = 0.01; ϵ = 1e-8, decay = 0)
|
|
|
|
|
#
|
|
|
|
|
# [ADAGrad](http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf) optimiser.
|
|
|
|
|
# Parameters don't need tuning.
|
|
|
|
|
# """
|
|
|
|
|
|
|
|
|
|
# """
|
|
|
|
|
# ADADelta(params; ρ = 0.9, ϵ = 1e-8, decay = 0)
|
|
|
|
|
#
|
|
|
|
|
# [ADADelta](http://arxiv.org/abs/1212.5701) optimiser. Parameters don't need
|
|
|
|
|
# tuning.
|
|
|
|
|
# """
|
|
|
|
|
|
|
|
|
|
# """
|
|
|
|
|
# AMSGrad(params; η = 0.001, β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0)
|
|
|
|
|
#
|
|
|
|
|
# [AMSGrad](https://openreview.net/forum?id=ryQu7f-RZ) optimiser. Parameters don't need
|
|
|
|
|
# tuning.
|
|
|
|
|
# """
|
|
|
|
|
|
|
|
|
|
# struct Optimiser
|
|
|
|
|
# os::Vector{Any}
|
|
|
|
|
# end
|
|
|
|
|
|
|
|
|
|
# TODO: decay
|