# Flux.jl/src/optimise/optimisers.jl

using Flux

const ϵ = 1e-8

# TODO: should use weak refs
"""
Descent(η)
Classic gradient descent optimiser with learning rate `η`.
For each parameter `p` and its gradient `δp`, this runs `p -= η*δp`.
"""
mutable struct Descent
  eta::Float64
end

function update!(o::Descent, x, Δ)
  Δ .*= o.eta
end
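# Usage sketch (illustrative only, not part of Flux's API): each `update!`
# method in this file rewrites the gradient `Δ` in place into the step to take,
# and the caller subtracts it from the parameter. The helper below, and the
# arrays `W`/`ΔW` it takes, are hypothetical stand-ins.
function _example_descent_step!(W::AbstractArray, ΔW::AbstractArray, η = 0.1)
  opt = Descent(η)     # plain gradient descent with learning rate η
  update!(opt, W, ΔW)  # ΔW now holds η .* ΔW
  W .-= ΔW             # take the step
  return W
end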
"""
    Momentum(η, ρ = 0.9)

Gradient descent with learning rate `η` and momentum `ρ`.
"""
mutable struct Momentum
  eta::Float64
  rho::Float64
  velocity::IdDict
end

Momentum(η, ρ = 0.9) = Momentum(η, ρ, IdDict())
function update!(o::Momentum, x, Δ)
  η, ρ = o.eta, o.rho
  v = get!(o.velocity, x, zero(x))::typeof(x)
  @. v = ρ * v - η * Δ
  @. Δ = -v
end
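# Usage sketch (illustrative only): `Momentum` caches one velocity buffer per
# parameter in its `velocity` IdDict, keyed by object identity, so the same
# array `W` must be passed on every call for momentum to accumulate. The helper
# and its arguments are hypothetical.
function _example_momentum_steps!(W::AbstractArray, grads)
  opt = Momentum(0.01, 0.9)
  for ΔW in grads        # one (hypothetical) gradient per iteration
    Δ = copy(ΔW)         # update! mutates the gradient buffer
    update!(opt, W, Δ)   # Δ becomes -v, using the velocity cached for W
    W .-= Δ
  end
  return W
end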
"""
    Nesterov(η, ρ = 0.9)

Gradient descent with learning rate `η` and Nesterov momentum `ρ`.
"""
mutable struct Nesterov
  eta::Float64
  rho::Float64
  velocity::IdDict
end

Nesterov(η, ρ = 0.9) = Nesterov(η, ρ, IdDict())
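# Note on the reformulation used below (explanatory comment, not from a Flux
# reference): with the velocity update v_new = ρ*v_old - η*Δ, the step actually
# applied to the parameters is
#
#   d = -ρ*v_old + (1 + ρ)*v_new = ρ^2*v_old - (1 + ρ)*η*Δ
#
# which is the standard way of expressing Nesterov's look-ahead momentum in
# terms of the gradient at the current parameters; `update!` stores -d in Δ so
# that the caller's `x -= Δ` performs `x += d`.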
function update!(o::Nesterov, x, Δ)
  η, ρ = o.eta, o.rho
  v = get!(o.velocity, x, zero(x))::typeof(x)
  d = @. ρ^2 * v - (1+ρ) * η * Δ
  @. v = ρ*v - η*Δ
  @. Δ = -d
end
"""
    RMSProp(η = 0.001, ρ = 0.9)

[RMSProp](http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf)
optimiser. Parameters other than learning rate don't need tuning. Often a good
choice for recurrent networks.
"""
mutable struct RMSProp
  eta::Float64
  rho::Float64
  acc::IdDict
end

RMSProp(η = 0.001, ρ = 0.9) = RMSProp(η, ρ, IdDict())
function update!(o::RMSProp, x, Δ)
  η, ρ = o.eta, o.rho
  acc = get!(o.acc, x, zero(x))::typeof(x)
  @. acc = ρ * acc + (1 - ρ) * Δ^2
  @. Δ *= η / (√acc + ϵ)
end
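# Explanatory sketch (illustrative only): `acc` above is an exponential moving
# average of the squared gradients, so after `update!` each entry of Δ has been
# scaled to roughly η * Δ / √E[Δ²], i.e. an adaptive per-parameter step size.
# The helper below is hypothetical and just wires the pieces together.
function _example_rmsprop_step!(W::AbstractArray, ΔW::AbstractArray)
  opt = RMSProp()       # defaults: η = 0.001, ρ = 0.9
  update!(opt, W, ΔW)   # ΔW ← ΔW * η / (√acc + ϵ), with acc cached for W
  W .-= ΔW
  return W
end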
"""
    ADAM(η = 0.001, β = (0.9, 0.999))

[ADAM](https://arxiv.org/abs/1412.6980v8) optimiser.
"""
mutable struct ADAM
  eta::Float64
  beta::Tuple{Float64,Float64}
  state::IdDict
end

ADAM(η = 0.001, β = (0.9, 0.999)) = ADAM(η, β, IdDict())
function update!(o::ADAM, x, Δ)
  η, β = o.eta, o.beta
  mt, vt, βp = get!(o.state, x, (zero(x), zero(x), β))
  @. mt = β[1] * mt + (1 - β[1]) * Δ
  @. vt = β[2] * vt + (1 - β[2]) * Δ^2
  @. Δ = mt / (1 - βp[1]) / (√(vt / (1 - βp[2])) + ϵ) * η
  o.state[x] = (mt, vt, βp .* β)
  return Δ
end
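# Usage sketch (illustrative only): `ADAM` keeps (mt, vt, βp) per parameter in
# `state`; βp holds the running powers β^t used for bias correction, so each
# call to `update!` with the same array advances the step count. The helper and
# its arguments are hypothetical.
function _example_adam_steps!(W::AbstractArray, grads)
  opt = ADAM()            # defaults: η = 0.001, β = (0.9, 0.999)
  for ΔW in grads
    Δ = copy(ΔW)          # update! mutates and returns the step
    W .-= update!(opt, W, Δ)
  end
  return W
end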
# """
# AdaMax(params, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0)
#
# [AdaMax](https://arxiv.org/abs/1412.6980v9) optimiser. Variant of ADAM based on
# the ∞-norm.
# """
# """
# ADAGrad(params, η = 0.01; ϵ = 1e-8, decay = 0)
#
# [ADAGrad](http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf) optimiser.
# Parameters don't need tuning.
# """
# """
# ADADelta(params; ρ = 0.9, ϵ = 1e-8, decay = 0)
#
# [ADADelta](http://arxiv.org/abs/1212.5701) optimiser. Parameters don't need
# tuning.
# """
# """
# AMSGrad(params; η = 0.001, β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0)
#
# [AMSGrad](https://openreview.net/forum?id=ryQu7f-RZ) optimiser. Parameters don't need
# tuning.
# """
# struct Optimiser
# os::Vector{Any}
# end
# TODO: decay