2018-05-31 19:29:59 +00:00
|
|
|
|
using Flux
|
|
|
|
|
using Base: @get!
|
2018-10-01 00:00:53 +00:00
|
|
|
|
using MacroTools: @forward
|
2017-08-22 21:25:18 +00:00
|
|
|
|
|
2018-05-31 19:29:59 +00:00
|
|
|
|
const ϵ = 1e-8  # small additive constant used in denominators (e.g. `√acc + ϵ`) to avoid division by zero
|
2018-07-03 10:11:32 +00:00
|
|
|
|
|
2018-05-31 19:29:59 +00:00
|
|
|
|
# TODO: should use weak refs
|
2017-09-01 21:06:51 +00:00
|
|
|
|
|
2018-05-31 19:29:59 +00:00
|
|
|
|
"""
|
2019-10-10 11:18:12 +00:00
|
|
|
|
Descent(η)
|
2018-05-31 19:29:59 +00:00
|
|
|
|
|
|
|
|
|
Classic gradient descent optimiser with learning rate `η`.
|
2019-10-09 10:46:11 +00:00
|
|
|
|
For each parameter `p` and its gradient `δp`, this runs `p -= η*δp`
|
|
|
|
|
|
|
|
|
|
## Parameters
|
2019-10-10 11:18:12 +00:00
|
|
|
|
- Learning Rate (η): The amount by which the gradients are discounted before updating the weights. Defaults to `0.1`.
|
2019-10-09 10:46:11 +00:00
|
|
|
|
|
|
|
|
|
## Example
|
|
|
|
|
```julia-repl
|
2019-10-10 11:18:12 +00:00
|
|
|
|
opt = Descent() # uses default η (0.1)
|
|
|
|
|
|
|
|
|
|
opt = Descent(0.3) # use provided η
|
2019-10-09 10:46:11 +00:00
|
|
|
|
|
|
|
|
|
ps = params(model)
|
|
|
|
|
|
|
|
|
|
gs = gradient(ps) do
|
|
|
|
|
loss(x, y)
|
|
|
|
|
end
|
|
|
|
|
|
2019-10-22 12:36:39 +00:00
|
|
|
|
Flux.Optimise.update!(opt, ps, gs)
|
2019-10-09 10:46:11 +00:00
|
|
|
|
```
|
2018-05-31 19:29:59 +00:00
|
|
|
|
"""
|
|
|
|
|
mutable struct Descent
|
|
|
|
|
eta::Float64
|
2017-09-01 21:06:51 +00:00
|
|
|
|
end
|
|
|
|
|
|
2018-10-05 11:43:03 +00:00
|
|
|
|
Descent() = Descent(0.1)
|
2018-10-31 14:58:55 +00:00
|
|
|
|
|
2019-01-28 13:59:23 +00:00
|
|
|
|
function apply!(o::Descent, x, Δ)
|
2018-05-31 19:29:59 +00:00
|
|
|
|
Δ .*= o.eta
|
2017-09-01 21:06:51 +00:00
|
|
|
|
end
|
|
|
|
|
|
2018-05-31 19:29:59 +00:00
|
|
|
|
"""
|
2019-09-28 10:39:00 +00:00
|
|
|
|
Momentum(η, ρ)
|
2019-10-04 09:16:03 +00:00
|
|
|
|
|
2018-05-31 19:29:59 +00:00
|
|
|
|
Gradient descent with learning rate `η` and momentum `ρ`.
|
2019-10-10 11:18:12 +00:00
|
|
|
|
|
|
|
|
|
## Parameters
|
|
|
|
|
- Learning Rate (`η`): Amount by which gradients are discounted before updating the weights. Defaults to `0.01`.
|
|
|
|
|
- Momentum (`ρ`): Parameter that accelerates descent in the relevant direction and dampens oscillations. Defaults to `0.9`.
|
|
|
|
|
|
|
|
|
|
## Examples
|
|
|
|
|
```julia
|
|
|
|
|
opt = Momentum() # uses defaults of η = 0.01 and ρ = 0.9
|
|
|
|
|
|
|
|
|
|
opt = Momentum(0.01, 0.99)
|
|
|
|
|
```
|
2018-05-31 19:29:59 +00:00
|
|
|
|
"""
|
|
|
|
|
mutable struct Momentum
|
|
|
|
|
eta::Float64
|
|
|
|
|
rho::Float64
|
2018-09-11 13:00:24 +00:00
|
|
|
|
velocity::IdDict
|
2017-09-01 21:06:51 +00:00
|
|
|
|
end
|
|
|
|
|
|
2018-10-01 00:00:53 +00:00
|
|
|
|
Momentum(η = 0.01, ρ = 0.9) = Momentum(η, ρ, IdDict())
|
2018-05-31 19:29:59 +00:00
|
|
|
|
|
2019-01-28 13:59:23 +00:00
|
|
|
|
function apply!(o::Momentum, x, Δ)
|
2018-05-31 19:29:59 +00:00
|
|
|
|
η, ρ = o.eta, o.rho
|
2019-03-08 12:13:58 +00:00
|
|
|
|
v = get!(o.velocity, x, zero(x))::typeof(x)
|
2018-05-31 19:29:59 +00:00
|
|
|
|
@. v = ρ * v - η * Δ
|
|
|
|
|
@. Δ = -v
|
2017-09-01 21:06:51 +00:00
|
|
|
|
end
|
2017-08-22 21:25:18 +00:00
|
|
|
|
|
2018-05-31 19:29:59 +00:00
|
|
|
|
"""
|
2019-09-28 10:39:00 +00:00
|
|
|
|
Nesterov(η, ρ)
|
2019-10-04 09:16:03 +00:00
|
|
|
|
|
2018-05-31 19:29:59 +00:00
|
|
|
|
Gradient descent with learning rate `η` and Nesterov momentum `ρ`.
|
2019-10-10 11:18:12 +00:00
|
|
|
|
|
|
|
|
|
## Parameters
|
|
|
|
|
- Learning Rate (η): Amount by which the gradients are dicsounted berfore updating the weights. Defaults to `0.001`.
|
|
|
|
|
- Nesterov Momentum (ρ): Paramters controlling the amount of nesterov momentum to be applied. Defaults to `0.9`.
|
|
|
|
|
|
|
|
|
|
## Examples
|
|
|
|
|
```julia
|
|
|
|
|
opt = Nesterov() # uses defaults η = 0.001 and ρ = 0.9
|
|
|
|
|
|
|
|
|
|
opt = Nesterov(0.003, 0.95)
|
|
|
|
|
```
|
2018-05-31 19:29:59 +00:00
|
|
|
|
"""
|
|
|
|
|
mutable struct Nesterov
|
|
|
|
|
eta::Float64
|
|
|
|
|
rho::Float64
|
2018-09-11 13:00:24 +00:00
|
|
|
|
velocity::IdDict
|
2017-08-22 21:25:18 +00:00
|
|
|
|
end
|
2017-12-04 08:17:05 +00:00
|
|
|
|
|
2018-10-01 00:00:53 +00:00
|
|
|
|
Nesterov(η = 0.001, ρ = 0.9) = Nesterov(η, ρ, IdDict())
|
2018-05-31 19:29:59 +00:00
|
|
|
|
|
2019-01-28 13:59:23 +00:00
|
|
|
|
function apply!(o::Nesterov, x, Δ)
|
2018-05-31 19:29:59 +00:00
|
|
|
|
η, ρ = o.eta, o.rho
|
2019-03-08 12:13:58 +00:00
|
|
|
|
v = get!(o.velocity, x, zero(x))::typeof(x)
|
2018-05-31 19:29:59 +00:00
|
|
|
|
d = @. ρ^2 * v - (1+ρ) * η * Δ
|
|
|
|
|
@. v = ρ*v - η*Δ
|
|
|
|
|
@. Δ = -d
|
2018-04-26 07:37:24 +00:00
|
|
|
|
end
|
|
|
|
|
|
2018-05-31 19:29:59 +00:00
|
|
|
|
"""
|
2019-09-28 10:39:00 +00:00
|
|
|
|
RMSProp(η, ρ)
|
2019-10-04 09:16:03 +00:00
|
|
|
|
|
2019-10-10 11:18:12 +00:00
|
|
|
|
Implements the RMSProp algortihm. Often a good choice for recurrent networks. Paramters other than learning rate generally don't need tuning.
|
|
|
|
|
|
|
|
|
|
## Parameters
|
|
|
|
|
- Learning Rate (η): Defaults to `0.001`.
|
|
|
|
|
- Rho (ρ): Defaults to `0.9`.
|
|
|
|
|
|
|
|
|
|
## Examples
|
|
|
|
|
```julia
|
|
|
|
|
opt = RMSProp() # uses default η = 0.001 and ρ = 0.9
|
2018-05-31 19:29:59 +00:00
|
|
|
|
|
2019-10-10 11:18:12 +00:00
|
|
|
|
opt = RMSProp(0.002, 0.95)
|
|
|
|
|
```
|
|
|
|
|
|
|
|
|
|
## References
|
2019-04-25 11:04:03 +00:00
|
|
|
|
[RMSProp](https://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf)
|
2018-05-31 19:29:59 +00:00
|
|
|
|
"""
|
|
|
|
|
mutable struct RMSProp
|
|
|
|
|
eta::Float64
|
|
|
|
|
rho::Float64
|
2018-09-11 13:00:24 +00:00
|
|
|
|
acc::IdDict
|
2017-12-04 08:17:05 +00:00
|
|
|
|
end
|
2017-12-08 18:20:53 +00:00
|
|
|
|
|
2018-09-11 13:00:24 +00:00
|
|
|
|
RMSProp(η = 0.001, ρ = 0.9) = RMSProp(η, ρ, IdDict())
|
2018-05-31 19:29:59 +00:00
|
|
|
|
|
2019-01-28 13:59:23 +00:00
|
|
|
|
function apply!(o::RMSProp, x, Δ)
|
2018-05-31 19:29:59 +00:00
|
|
|
|
η, ρ = o.eta, o.rho
|
2019-03-08 12:13:58 +00:00
|
|
|
|
acc = get!(o.acc, x, zero(x))::typeof(x)
|
2018-05-31 19:29:59 +00:00
|
|
|
|
@. acc = ρ * acc + (1 - ρ) * Δ^2
|
|
|
|
|
@. Δ *= η / (√acc + ϵ)
|
2018-04-02 19:57:22 +00:00
|
|
|
|
end
|
|
|
|
|
|
2018-05-31 19:29:59 +00:00
|
|
|
|
"""
|
2019-10-04 09:16:03 +00:00
|
|
|
|
ADAM(η, β::Tuple)
|
|
|
|
|
|
2019-10-10 11:18:12 +00:00
|
|
|
|
Implements the ADAM optimiser.
|
|
|
|
|
|
|
|
|
|
## Paramters
|
|
|
|
|
- Learning Rate (`η`): Defaults to `0.001`.
|
|
|
|
|
- Beta (`β::Tuple`): The first element refers to β1 and the second to β2. Defaults to `(0.9, 0.999)`.
|
2017-10-12 08:31:38 +00:00
|
|
|
|
|
2019-10-10 11:18:12 +00:00
|
|
|
|
## Examples
|
|
|
|
|
|
|
|
|
|
```julia
|
|
|
|
|
opt = ADAM() # uses the default η = 0.001 and β = (0.9, 0.999)
|
|
|
|
|
|
|
|
|
|
opt = ADAM(0.001, (0.9, 0.8))
|
|
|
|
|
```
|
|
|
|
|
## References
|
2018-05-31 19:29:59 +00:00
|
|
|
|
[ADAM](https://arxiv.org/abs/1412.6980v8) optimiser.
|
|
|
|
|
"""
|
|
|
|
|
mutable struct ADAM
|
|
|
|
|
eta::Float64
|
|
|
|
|
beta::Tuple{Float64,Float64}
|
2018-09-11 13:00:24 +00:00
|
|
|
|
state::IdDict
|
2017-10-12 08:31:38 +00:00
|
|
|
|
end
|
|
|
|
|
|
2018-09-11 13:00:24 +00:00
|
|
|
|
ADAM(η = 0.001, β = (0.9, 0.999)) = ADAM(η, β, IdDict())
|
2018-05-31 19:29:59 +00:00
|
|
|
|
|
2019-01-28 13:59:23 +00:00
|
|
|
|
function apply!(o::ADAM, x, Δ)
|
2018-05-31 19:29:59 +00:00
|
|
|
|
η, β = o.eta, o.beta
|
2018-09-11 13:00:24 +00:00
|
|
|
|
mt, vt, βp = get!(o.state, x, (zero(x), zero(x), β))
|
2018-05-31 19:29:59 +00:00
|
|
|
|
@. mt = β[1] * mt + (1 - β[1]) * Δ
|
|
|
|
|
@. vt = β[2] * vt + (1 - β[2]) * Δ^2
|
|
|
|
|
@. Δ = mt / (1 - βp[1]) / (√(vt / (1 - βp[2])) + ϵ) * η
|
|
|
|
|
o.state[x] = (mt, vt, βp .* β)
|
2018-09-14 15:02:56 +00:00
|
|
|
|
return Δ
|
2017-12-08 18:20:53 +00:00
|
|
|
|
end
|
2018-05-31 19:29:59 +00:00
|
|
|
|
|
2019-08-19 04:22:32 +00:00
|
|
|
|
"""
|
2019-10-04 09:16:03 +00:00
|
|
|
|
RADAM(η, β::Tuple)
|
|
|
|
|
|
2019-10-10 11:18:12 +00:00
|
|
|
|
Implements the rectified ADAM optimizer.
|
2019-08-19 04:22:32 +00:00
|
|
|
|
|
2019-10-10 11:18:12 +00:00
|
|
|
|
## Parameters
|
|
|
|
|
- Learning Rate (η): Defaults to `0.001`
|
|
|
|
|
- Beta (β::Tuple): The first element refers to β1 and the second to β2. Defaults to `(0.9, 0.999)`.
|
|
|
|
|
|
|
|
|
|
## Examples
|
|
|
|
|
|
|
|
|
|
```julia
|
|
|
|
|
opt = RADAM() # uses the default η = 0.001 and β = (0.9, 0.999)
|
|
|
|
|
|
|
|
|
|
opt = RADAM(0.001, (0.9, 0.8))
|
|
|
|
|
```
|
|
|
|
|
|
|
|
|
|
## References
|
2019-08-19 04:22:32 +00:00
|
|
|
|
[RADAM](https://arxiv.org/pdf/1908.03265v1.pdf) optimiser (Rectified ADAM).
|
|
|
|
|
"""
|
|
|
|
|
mutable struct RADAM
|
|
|
|
|
eta::Float64
|
|
|
|
|
beta::Tuple{Float64,Float64}
|
|
|
|
|
state::IdDict
|
|
|
|
|
end
|
|
|
|
|
|
|
|
|
|
RADAM(η = 0.001, β = (0.9, 0.999)) = RADAM(η, β, IdDict())
|
|
|
|
|
|
|
|
|
|
function apply!(o::RADAM, x, Δ)
|
|
|
|
|
η, β = o.eta, o.beta
|
|
|
|
|
ρ∞ = 2/(1-β[2])-1
|
|
|
|
|
mt, vt, βp, t = get!(o.state, x, (zero(x), zero(x), β, 1))
|
|
|
|
|
@. mt = β[1] * mt + (1 - β[1]) * Δ
|
|
|
|
|
@. vt = β[2] * vt + (1 - β[2]) * Δ^2
|
|
|
|
|
ρ = ρ∞ - 2t*βp[2]/(1-βp[2])
|
|
|
|
|
if ρ > 4
|
|
|
|
|
r = sqrt((ρ-4)*(ρ-2)*ρ∞/((ρ∞-4)*(ρ∞-2)*ρ))
|
|
|
|
|
@. Δ = mt / (1 - βp[1]) / (√(vt / (1 - βp[2])) + ϵ) * η * r
|
|
|
|
|
else
|
|
|
|
|
@. Δ = mt / (1 - βp[1]) * η
|
|
|
|
|
end
|
|
|
|
|
o.state[x] = (mt, vt, βp .* β, t+1)
|
|
|
|
|
return Δ
|
|
|
|
|
end
|
|
|
|
|
|
2018-09-16 12:04:51 +00:00
|
|
|
|
"""
|
2019-10-04 09:16:03 +00:00
|
|
|
|
AdaMax(η, β::Tuple)
|
|
|
|
|
|
2019-10-10 11:18:12 +00:00
|
|
|
|
Variant of ADAM based on ∞-norm.
|
|
|
|
|
|
|
|
|
|
## Parameters
|
|
|
|
|
- Learning Rate (η): Defaults to `0.001`
|
|
|
|
|
- Beta (β::Tuple): The first element refers to β1 and the second to β2. Defaults to `(0.9, 0.999)`.
|
2018-09-16 12:04:51 +00:00
|
|
|
|
|
2019-10-10 11:18:12 +00:00
|
|
|
|
## Examples
|
|
|
|
|
```julia
|
2019-10-22 12:36:39 +00:00
|
|
|
|
opt = AdaMax() # uses default η and β
|
2019-10-10 11:18:12 +00:00
|
|
|
|
|
|
|
|
|
opt = AdaMax(0.001, (0.9, 0.995))
|
|
|
|
|
```
|
|
|
|
|
## References
|
|
|
|
|
[AdaMax](https://arxiv.org/abs/1412.6980v9) optimiser.
|
2018-09-16 12:04:51 +00:00
|
|
|
|
"""
|
|
|
|
|
mutable struct AdaMax
|
|
|
|
|
eta::Float64
|
|
|
|
|
beta::Tuple{Float64,Float64}
|
|
|
|
|
state::IdDict
|
|
|
|
|
end
|
|
|
|
|
|
|
|
|
|
AdaMax(η = 0.001, β = (0.9, 0.999)) = AdaMax(η, β, IdDict())
|
|
|
|
|
|
2019-01-28 13:59:23 +00:00
|
|
|
|
function apply!(o::AdaMax, x, Δ)
|
2018-09-16 12:04:51 +00:00
|
|
|
|
η, β = o.eta, o.beta
|
|
|
|
|
mt, ut, βp = get!(o.state, x, (zero(x), zero(x), β))
|
|
|
|
|
@. mt = β[1] * mt + (1 - β[1]) * Δ
|
|
|
|
|
@. ut = max(β[2] * ut, abs(Δ))
|
|
|
|
|
@. Δ = (η/(1 - βp[1])) * mt/(ut + ϵ)
|
|
|
|
|
o.state[x] = (mt, ut, βp .* β)
|
|
|
|
|
return Δ
|
|
|
|
|
end
|
|
|
|
|
|
|
|
|
|
"""
|
2019-09-28 10:39:00 +00:00
|
|
|
|
ADAGrad(η)
|
2019-10-04 09:16:03 +00:00
|
|
|
|
|
2019-10-10 11:18:12 +00:00
|
|
|
|
Implements AdaGrad. It has parameter specific learning rates based on how frequently it is updated.
|
|
|
|
|
|
|
|
|
|
## Parameters
|
|
|
|
|
- Learning Rate (η): Defaults to `0.1`
|
|
|
|
|
|
|
|
|
|
## Examples
|
|
|
|
|
```julia
|
|
|
|
|
opt = ADAGrad() # uses default η = 0.1
|
|
|
|
|
|
|
|
|
|
opt = ADAGrad(0.001)
|
|
|
|
|
```
|
2018-09-16 12:04:51 +00:00
|
|
|
|
|
2019-10-10 11:18:12 +00:00
|
|
|
|
## References
|
2018-09-16 12:04:51 +00:00
|
|
|
|
[ADAGrad](http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf) optimiser.
|
|
|
|
|
Parameters don't need tuning.
|
|
|
|
|
"""
|
|
|
|
|
mutable struct ADAGrad
|
|
|
|
|
eta::Float64
|
|
|
|
|
acc::IdDict
|
|
|
|
|
end
|
|
|
|
|
|
|
|
|
|
ADAGrad(η = 0.1) = ADAGrad(η, IdDict())
|
|
|
|
|
|
2019-01-28 13:59:23 +00:00
|
|
|
|
function apply!(o::ADAGrad, x, Δ)
|
2018-09-16 12:04:51 +00:00
|
|
|
|
η = o.eta
|
2019-11-19 08:31:04 +00:00
|
|
|
|
acc = get!(o.acc, x, fill!(zero(x), ϵ))::typeof(x)
|
2018-09-16 12:04:51 +00:00
|
|
|
|
@. acc += Δ^2
|
2018-11-02 11:59:04 +00:00
|
|
|
|
@. Δ *= η / (√acc + ϵ)
|
2018-09-16 12:04:51 +00:00
|
|
|
|
end
|
|
|
|
|
|
|
|
|
|
"""
|
2019-09-28 10:39:00 +00:00
|
|
|
|
ADADelta(ρ)
|
2019-10-04 09:16:03 +00:00
|
|
|
|
|
2019-10-10 11:18:12 +00:00
|
|
|
|
Version of ADAGrad that adapts learning rate based on a window of past gradient updates. Parameters don't need tuning.
|
|
|
|
|
|
|
|
|
|
## Parameters
|
|
|
|
|
- Rho (ρ): Factor by which gradient is decayed at each time step. Defaults to `0.9`.
|
2018-09-16 12:04:51 +00:00
|
|
|
|
|
2019-10-10 11:18:12 +00:00
|
|
|
|
## Examples
|
|
|
|
|
```julia
|
|
|
|
|
opt = ADADelta() # uses default ρ = 0.9
|
|
|
|
|
opt = ADADelta(0.89)
|
|
|
|
|
```
|
|
|
|
|
|
|
|
|
|
## References
|
|
|
|
|
[ADADelta](https://arxiv.org/abs/1212.5701) optimiser.
|
2018-09-16 12:04:51 +00:00
|
|
|
|
"""
|
|
|
|
|
mutable struct ADADelta
|
|
|
|
|
rho::Float64
|
|
|
|
|
state::IdDict
|
|
|
|
|
end
|
|
|
|
|
|
|
|
|
|
ADADelta(ρ = 0.9) = ADADelta(ρ, IdDict())
|
|
|
|
|
|
2019-01-28 13:59:23 +00:00
|
|
|
|
function apply!(o::ADADelta, x, Δ)
|
2018-09-16 12:04:51 +00:00
|
|
|
|
ρ = o.rho
|
|
|
|
|
acc, Δacc = get!(o.state, x, (zero(x), zero(x)))
|
|
|
|
|
@. acc = ρ * acc + (1 - ρ) * Δ^2
|
2018-11-02 11:59:04 +00:00
|
|
|
|
@. Δ *= √Δacc/ (√acc + ϵ)
|
2018-09-16 12:04:51 +00:00
|
|
|
|
@. Δacc = ρ * Δacc + (1 - ρ) * Δ^2
|
|
|
|
|
return Δ
|
|
|
|
|
end
|
|
|
|
|
|
|
|
|
|
"""
|
2019-10-04 09:16:03 +00:00
|
|
|
|
AMSGrad(η, β::Tuple)
|
|
|
|
|
|
2019-10-10 11:18:12 +00:00
|
|
|
|
Implements AMSGrad version of the ADAM optimiser. Parameters don't need tuning.
|
|
|
|
|
|
|
|
|
|
## Parameters
|
|
|
|
|
- Learning Rate (η): Defaults to `0.001`.
|
|
|
|
|
- Beta (β::Tuple): The first element refers to β1 and the second to β2. Defaults to `(0.9, 0.999)`.
|
2018-09-16 12:04:51 +00:00
|
|
|
|
|
2019-10-10 11:18:12 +00:00
|
|
|
|
## Examples
|
|
|
|
|
```julia
|
|
|
|
|
opt = AMSGrad() # uses default η and β
|
|
|
|
|
opt = AMSGrad(0.001, (0.89, 0.995))
|
|
|
|
|
```
|
|
|
|
|
|
|
|
|
|
## References
|
|
|
|
|
[AMSGrad](https://openreview.net/forum?id=ryQu7f-RZ) optimiser.
|
2018-09-16 12:04:51 +00:00
|
|
|
|
"""
|
|
|
|
|
mutable struct AMSGrad
|
|
|
|
|
eta::Float64
|
|
|
|
|
beta::Tuple{Float64, Float64}
|
|
|
|
|
state::IdDict
|
|
|
|
|
end
|
|
|
|
|
|
|
|
|
|
AMSGrad(η = 0.001, β = (0.9, 0.999)) = AMSGrad(η, β, IdDict())
|
|
|
|
|
|
2019-01-28 13:59:23 +00:00
|
|
|
|
function apply!(o::AMSGrad, x, Δ)
|
2018-09-16 12:04:51 +00:00
|
|
|
|
η, β = o.eta, o.beta
|
2019-11-19 08:27:44 +00:00
|
|
|
|
mt, vt, v̂t = get!(o.state, x, (fill!(zero(x), ϵ), fill!(zero(x), ϵ), fill!(zero(x), ϵ)))
|
2018-09-16 12:04:51 +00:00
|
|
|
|
@. mt = β[1] * mt + (1 - β[1]) * Δ
|
|
|
|
|
@. vt = β[2] * vt + (1 - β[2]) * Δ ^ 2
|
2019-11-19 08:27:44 +00:00
|
|
|
|
@. v̂t = max(v̂t, vt)
|
2018-11-02 11:59:04 +00:00
|
|
|
|
@. Δ = η * mt / (√v̂t + ϵ)
|
2018-09-16 12:04:51 +00:00
|
|
|
|
end
|
|
|
|
|
|
|
|
|
|
"""
|
2019-10-04 09:16:03 +00:00
|
|
|
|
NADAM(η, β::Tuple)
|
|
|
|
|
|
2019-10-10 11:18:12 +00:00
|
|
|
|
Nesterov variant of ADAM. Parameters don't need tuning.
|
2018-09-16 12:04:51 +00:00
|
|
|
|
|
2019-10-10 11:18:12 +00:00
|
|
|
|
## Parameters
|
|
|
|
|
- Learning Rate (η): Defaults to `0.001`.
|
|
|
|
|
- Beta (β::Tuple): The first element refers to β1 and the second to β2. Defaults to `(0.9, 0.999)`.
|
|
|
|
|
|
|
|
|
|
## Examples
|
|
|
|
|
```julia
|
|
|
|
|
opt = NADAM() # uses default η and β
|
|
|
|
|
opt = NADAM(0.002, (0.89, 0.995))
|
|
|
|
|
```
|
|
|
|
|
|
|
|
|
|
## References
|
|
|
|
|
[NADAM](http://cs229.stanford.edu/proj2015/054_report.pdf) optimiser.
|
2018-09-16 12:04:51 +00:00
|
|
|
|
"""
|
|
|
|
|
mutable struct NADAM
|
|
|
|
|
eta::Float64
|
|
|
|
|
beta::Tuple{Float64, Float64}
|
|
|
|
|
state::IdDict
|
|
|
|
|
end
|
|
|
|
|
|
|
|
|
|
NADAM(η = 0.001, β = (0.9, 0.999)) = NADAM(η, β, IdDict())
|
|
|
|
|
|
2019-01-28 13:59:23 +00:00
|
|
|
|
function apply!(o::NADAM, x, Δ)
|
2018-09-16 12:04:51 +00:00
|
|
|
|
η, β = o.eta, o.beta
|
2019-06-16 13:36:59 +00:00
|
|
|
|
mt, vt, (β1p, β2p) = get!(o.state, x, (zero(x), zero(x), o.beta))
|
2018-09-16 12:04:51 +00:00
|
|
|
|
@. mt = β[1] * mt + (1 - β[1]) * Δ
|
|
|
|
|
@. vt = β[2] * vt + (1 - β[2]) * Δ^2
|
2018-11-02 11:59:04 +00:00
|
|
|
|
@. Δ = (β[1] * mt / (1 - β[1] * β1p) + (1 - β[1]) * Δ / (1 - β1p)) / (√(vt * β[2] / (1 - β2p)) + ϵ) * η
|
2018-09-16 12:04:51 +00:00
|
|
|
|
o.state[x] = (mt, vt, (β1p * β[1], β2p * β[2]))
|
|
|
|
|
return Δ
|
|
|
|
|
end
|
|
|
|
|
|
2018-10-01 00:00:53 +00:00
|
|
|
|
"""
|
2019-10-04 09:16:03 +00:00
|
|
|
|
ADAMW(η, β::Tuple, decay)
|
|
|
|
|
|
2019-10-10 11:18:12 +00:00
|
|
|
|
Variant of ADAM defined by fixing weight decay regularization.
|
|
|
|
|
|
|
|
|
|
## Parameters
|
|
|
|
|
- Learning Rate (η): Defaults to `0.001`.
|
|
|
|
|
- Beta (β::Tuple): The first element refers to β1 and the second to β2. Defaults to (0.9, 0.999).
|
|
|
|
|
- decay: Decay applied to weights during optimisation. Defaults to 0.
|
|
|
|
|
|
|
|
|
|
## Examples
|
|
|
|
|
```julia
|
|
|
|
|
opt = ADAMW() # uses default η, β and decay
|
2019-10-22 12:36:39 +00:00
|
|
|
|
opt = ADAMW(0.001, (0.89, 0.995), 0.1)
|
2019-10-10 11:18:12 +00:00
|
|
|
|
```
|
2018-10-01 00:00:53 +00:00
|
|
|
|
|
2019-10-10 11:18:12 +00:00
|
|
|
|
## References
|
|
|
|
|
[ADAMW](https://arxiv.org/abs/1711.05101)
|
2018-10-01 00:00:53 +00:00
|
|
|
|
"""
|
2018-10-31 14:58:55 +00:00
|
|
|
|
ADAMW(η = 0.001, β = (0.9, 0.999), decay = 0) =
|
2018-12-12 11:17:42 +00:00
|
|
|
|
Optimiser(ADAM(η, β), WeightDecay(decay))
|
2018-10-01 00:00:53 +00:00
|
|
|
|
|
|
|
|
|
# Compose optimizers

"""
    Optimiser(a, b, c...)

Combine several optimisers into one; each optimiser produces a modified
gradient that will be fed into the next, and this is finally applied to the
parameter as usual.
"""
mutable struct Optimiser
  os::Vector{Any}
end

Optimiser(o...) = Optimiser(Any[o...])

# Make an Optimiser behave like the vector of optimisers it wraps.
@forward Optimiser.os Base.getindex, Base.first, Base.last, Base.lastindex, Base.push!, Base.setindex!
@forward Optimiser.os Base.iterate

Base.getindex(c::Optimiser, i::AbstractArray) = Optimiser(c.os[i]...)

function apply!(o::Optimiser, x, Δ)
  # Thread the gradient through every member optimiser in order.
  for member in o.os
    Δ = apply!(member, x, Δ)
  end
  return Δ
end
|
2018-05-31 19:29:59 +00:00
|
|
|
|
|
2018-11-12 13:47:10 +00:00
|
|
|
|
"""
|
2019-10-04 09:16:03 +00:00
|
|
|
|
InvDecay(γ)
|
|
|
|
|
|
2019-10-10 11:18:12 +00:00
|
|
|
|
Applies inverse time decay to an optimiser
|
2018-11-12 13:47:10 +00:00
|
|
|
|
|
2019-10-10 11:18:12 +00:00
|
|
|
|
## Parameters
|
|
|
|
|
- gamma (γ): Defaults to `0.001`
|
|
|
|
|
|
|
|
|
|
## Example
|
2018-11-12 13:47:10 +00:00
|
|
|
|
```julia
|
|
|
|
|
Optimiser(InvDecay(..), Opt(..))
|
|
|
|
|
```
|
|
|
|
|
"""
|
2018-09-16 12:04:51 +00:00
|
|
|
|
mutable struct InvDecay
|
|
|
|
|
gamma::Float64
|
2018-10-27 13:56:42 +00:00
|
|
|
|
state::IdDict
|
2018-09-16 12:04:51 +00:00
|
|
|
|
end
|
|
|
|
|
|
2018-10-27 13:56:42 +00:00
|
|
|
|
InvDecay(γ = 0.001) = InvDecay(γ, IdDict())
|
2018-09-16 12:04:51 +00:00
|
|
|
|
|
2019-01-28 13:59:23 +00:00
|
|
|
|
function apply!(o::InvDecay, x, Δ)
|
2018-10-27 13:56:42 +00:00
|
|
|
|
γ = o.gamma
|
|
|
|
|
n = get!(o.state, x, 1)
|
2018-09-16 12:04:51 +00:00
|
|
|
|
Δ .*= 1 / (1 + γ * n)
|
2018-10-27 13:56:42 +00:00
|
|
|
|
o.state[x] = n + 1
|
2018-09-16 12:04:51 +00:00
|
|
|
|
return Δ
|
|
|
|
|
end
|
|
|
|
|
|
2018-11-12 13:47:10 +00:00
|
|
|
|
"""
|
2019-10-04 09:16:03 +00:00
|
|
|
|
ExpDecay(eta, decay, decay_step, clip)
|
|
|
|
|
|
|
|
|
|
Discount the learning rate `eta` by `decay` every `decay_step` till a minimum of `clip`.
|
2018-11-12 13:47:10 +00:00
|
|
|
|
|
2019-10-10 11:18:12 +00:00
|
|
|
|
## Parameters
|
|
|
|
|
- Learning Rate (eta): Defaults to `0.001`.
|
|
|
|
|
- decay: Factor by which the learning rate is discounted. Defaults to `0.1`.
|
|
|
|
|
- decay_step: Schedules decay operations by setting number of steps between two decay operations. Defaults to `1000`.
|
|
|
|
|
- clip: Minimum value of learning rate. Defaults to `1e-4`.
|
|
|
|
|
|
|
|
|
|
## Example
|
2018-11-12 13:47:10 +00:00
|
|
|
|
To apply exponential decay to an optimiser:
|
|
|
|
|
```julia
|
|
|
|
|
Optimiser(ExpDecay(..), Opt(..))
|
2019-10-10 11:18:12 +00:00
|
|
|
|
|
|
|
|
|
opt = Optimiser(ExpDecay(), ADAM())
|
2018-11-12 13:47:10 +00:00
|
|
|
|
```
|
|
|
|
|
"""
|
2018-09-16 12:04:51 +00:00
|
|
|
|
mutable struct ExpDecay
|
2018-10-29 17:42:24 +00:00
|
|
|
|
eta::Float64
|
2018-10-27 13:56:42 +00:00
|
|
|
|
decay::Float64
|
|
|
|
|
step::Int64
|
|
|
|
|
clip::Float64
|
|
|
|
|
current::IdDict
|
2018-09-16 12:04:51 +00:00
|
|
|
|
end
|
|
|
|
|
|
2018-10-29 17:42:24 +00:00
|
|
|
|
ExpDecay(opt = 0.001, decay = 0.1, decay_step = 1000, clip = 1e-4) = ExpDecay(opt, decay, decay_step, clip, IdDict())
|
2018-09-16 12:04:51 +00:00
|
|
|
|
|
2019-01-28 13:59:23 +00:00
|
|
|
|
function apply!(o::ExpDecay, x, Δ)
|
2018-10-29 17:42:24 +00:00
|
|
|
|
η, s, decay = o.eta, o.step, o.decay
|
2018-10-27 13:56:42 +00:00
|
|
|
|
n = o.current[x] = get(o.current, x, 0) + 1
|
2018-10-31 14:58:55 +00:00
|
|
|
|
if o.current[x]%s == 0 && count(x -> x%s == 0, values(o.current)) == 1
|
2018-10-27 13:56:42 +00:00
|
|
|
|
η = max(η * decay^(s / n), o.clip)
|
2018-10-29 17:42:24 +00:00
|
|
|
|
o.eta = η
|
2018-10-27 13:56:42 +00:00
|
|
|
|
end
|
2019-04-11 11:58:06 +00:00
|
|
|
|
@. Δ *= η
|
2018-09-16 12:04:51 +00:00
|
|
|
|
end
|
2018-10-01 00:00:53 +00:00
|
|
|
|
|
2018-11-12 13:47:10 +00:00
|
|
|
|
"""
|
2019-10-04 09:16:03 +00:00
|
|
|
|
WeightDecay(wd)
|
|
|
|
|
|
2019-10-10 11:18:12 +00:00
|
|
|
|
Decays the weight by `wd`
|
2018-11-12 13:47:10 +00:00
|
|
|
|
|
2019-10-10 11:18:12 +00:00
|
|
|
|
## Parameters
|
|
|
|
|
- weight decay (wd): 0
|
2018-11-12 13:47:10 +00:00
|
|
|
|
"""
|
2018-10-11 04:37:16 +00:00
|
|
|
|
mutable struct WeightDecay
|
|
|
|
|
wd::Real
|
2018-10-01 00:00:53 +00:00
|
|
|
|
end
|
|
|
|
|
|
2018-10-27 13:56:42 +00:00
|
|
|
|
WeightDecay() = WeightDecay(0)
|
2018-10-31 14:58:55 +00:00
|
|
|
|
|
2019-02-28 14:58:42 +00:00
|
|
|
|
function apply!(o::WeightDecay, x, Δ)
|
2018-10-27 13:56:42 +00:00
|
|
|
|
wd = o.wd
|
2019-03-08 12:13:58 +00:00
|
|
|
|
@. Δ += wd * x
|
2018-10-01 00:00:53 +00:00
|
|
|
|
end
|