Flux.jl/src/optimise/optimisers.jl

131 lines
3.0 KiB
Julia
Raw Normal View History

2017-09-01 21:06:51 +00:00
# Plain gradient descent.
# Returns a zero-argument closure; each call moves `p.x` by `-η * p.Δ`
# and then resets every entry of `p.Δ` to zero.
function descent(p::Param, η::Real)
    return function ()
        p.x .-= η .* p.Δ
        p.Δ .= 0
    end
end
# Gradient descent with a weight-decay term `γ * p.x` added to `p.Δ`.
# Ref: https://arxiv.org/abs/1711.05101
# Returns a zero-argument closure; each call sets
# `p.x = p.x - η * (p.Δ + γ * p.x)` and then zeroes `p.Δ`.
function descentweightdecay(p::Param, η::Real, γ::Real)
    return function ()
        p.x .-= η .* (p.Δ .+ γ .* p.x)
        p.Δ .= 0
    end
end
2017-10-12 08:31:38 +00:00
# Classical momentum.
# Keeps a velocity buffer shaped like `p.x`. Each call of the returned closure
# updates `velocity = ρ * velocity - η * p.Δ` and overwrites `p.Δ` with
# `-velocity`. NOTE(review): presumably a later step (e.g. `descent`) applies
# the rewritten `p.Δ` to `p.x`; that composition is not visible here.
function momentum(p::Param, ρ, η)
    velocity = zero(p.x)
    return function ()
        velocity .= ρ .* velocity .- η .* p.Δ
        p.Δ .= .-velocity
    end
end
2017-10-12 08:31:38 +00:00
# Nesterov momentum.
# Ref. https://arxiv.org/pdf/1212.0901.pdf
# Each call of the returned closure computes the step
# `ρ^2 * velocity - (1 + ρ) * η * p.Δ` from the *previous* velocity,
# then advances the velocity, and finally stores the negated step in `p.Δ`.
function nesterov(p::Param, ρ, η)
    velocity = zero(p.x)
    return function ()
        # Must be computed before `velocity` is updated below.
        step = @. ρ^2 * velocity - (1 + ρ) * η * p.Δ
        velocity .= ρ .* velocity .- η .* p.Δ
        p.Δ .= .-step
    end
end
# RMSProp.
# Keeps an exponential moving average of the squared `p.Δ` in `acc`.
# Each call of the returned closure rescales `p.Δ` in place by
# `η / (√acc + ϵ)`.
function rmsprop(p::Param; η::Real = 0.001, ρ::Real = 0.9, ϵ::Real = 1e-8)
    acc = zero(p.x)
    function ()
        @. acc = ρ * acc + (1 - ρ) * p.Δ^2
        # Divide by the *root* of the running mean square, not the mean square itself.
        @. p.Δ *= η / (√acc + ϵ)
    end
end
# AdaGrad.
# Accumulates the running sum of squared `p.Δ` values in `acc`, which is
# seeded with `ϵ`. Each call of the returned closure rescales `p.Δ` in place
# by `η / (√acc + ϵ)`.
function adagrad(p::Param; η::Real = 0.01, ϵ::Real = 1e-8)
    acc = zero(p.x) .+ ϵ
    function ()
        @. acc += p.Δ^2
        # Scale by the inverse *root* of the accumulated squares.
        @. p.Δ *= η / (√acc + ϵ)
    end
end
2017-10-12 08:31:38 +00:00
# AdaDelta.
# Keeps two exponential moving averages:
#   `acc`  - of the squared incoming `p.Δ`;
#   `Δacc` - of the squared rescaled `p.Δ` produced by previous calls.
# Each call multiplies `p.Δ` by the ratio of their roots,
# `√(Δacc + ϵ) / √(acc + ϵ)`.
function adadelta(p::Param; ρ::Real = 0.9, ϵ::Real = 1e-8)
    acc = zero(p.x)
    Δacc = zero(p.x)
    function ()
        @. acc = ρ * acc + (1 - ρ) * p.Δ^2
        # Uses Δacc from the previous call, so this must run before Δacc is updated.
        @. p.Δ *= √(Δacc + ϵ) / √(acc + ϵ)
        @. Δacc = ρ * Δacc + (1 - ρ) * p.Δ^2
    end
end
2017-08-22 21:25:18 +00:00
2017-09-01 21:06:51 +00:00
# Adam.
# `mt` and `vt` are exponential moving averages of `p.Δ` and `p.Δ^2`.
# `β1p`/`β2p` hold β1^t and β2^t for bias correction, with t = 1 on the
# first call. Each call overwrites `p.Δ` with
# `η * m̂ / (√v̂ + ϵ)`, where `m̂ = mt / (1 - β1p)` and `v̂ = vt / (1 - β2p)`.
function adam(p::Param; η::Real = 0.001, β1::Real = 0.9, β2::Real = 0.999, ϵ::Real = 1e-8)
    mt = zero(p.x)
    vt = zero(p.x)
    β1p, β2p = β1, β2
    function ()
        @. mt = β1 * mt + (1 - β1) * p.Δ
        @. vt = β2 * vt + (1 - β2) * p.Δ^2
        # The second moment enters through its square root.
        @. p.Δ = mt / (1 - β1p) / (√(vt / (1 - β2p)) + ϵ) * η
        β1p *= β1
        β2p *= β2
    end
end
2017-12-04 08:17:05 +00:00
2018-04-26 07:37:24 +00:00
# AdaMax.
# `m` is an exponential moving average of `p.Δ`, and `u` is an exponentially
# weighted infinity norm, `max(β2 * u, |p.Δ|)`. `β1_pow` holds β1^t for bias
# correction of `m`. Each call overwrites `p.Δ` with
# `(η / (1 - β1^t)) * m / (u + ϵ)` and then advances `β1_pow`.
function adamax(p::Param; η::Real = 0.002, β1::Real = 0.9, β2::Real = 0.999, ϵ::Real = 1e-8)
    m = zero(p.x)
    u = zero(p.x)
    β1_pow = β1
    return function ()
        m .= β1 .* m .+ (1 .- β1) .* p.Δ
        u .= max.(β2 .* u, abs.(p.Δ))
        p.Δ .= (η ./ (1 .- β1_pow)) .* m ./ (u .+ ϵ)
        β1_pow *= β1
    end
end
2017-12-04 08:17:05 +00:00
# AMSGrad.
# Like Adam's moment estimates, but normalizes by `v̂t`, the element-wise
# running maximum of the second-moment average `vt`. Each call overwrites
# `p.Δ` with `η * mt / √v̂t`.
function amsgrad(p::Param; η::Real = 0.001, β1::Real = 0.9, β2::Real = 0.999, ϵ::Real = 1e-8)
    mt = zero(p.x)
    vt = zero(p.x) .+ ϵ
    # Seeded with ϵ and only ever increased, so √v̂t stays positive for ϵ > 0.
    v̂t = zero(p.x) .+ ϵ
    function ()
        @. mt = β1 * mt + (1 - β1) * p.Δ
        @. vt = β2 * vt + (1 - β2) * p.Δ^2
        @. v̂t = max(v̂t, vt)
        # Normalize by the root of the max second moment, not the moment itself.
        @. p.Δ = η * mt / √v̂t
    end
end
2017-12-08 18:20:53 +00:00
2018-04-02 19:57:22 +00:00
# NAdam (Adam with a Nesterov-style first-moment look-ahead).
# `mt`/`vt` are exponential moving averages of `p.Δ` and `p.Δ^2`.
# `β1p`/`β2p` hold β1^t and β2^t. Each call blends the bias-corrected
# momentum with the current `p.Δ`, divides by the root of the
# bias-corrected second moment plus `ϵ`, and scales by `η`.
function nadam(p::Param; η::Real = 0.001, β1::Real = 0.9, β2::Real = 0.999, ϵ::Real = 1e-8)
    mt = zero(p.x)
    vt = zero(p.x)
    β1p, β2p = β1, β2
    function ()
        @. mt = β1 * mt + (1 - β1) * p.Δ
        @. vt = β2 * vt + (1 - β2) * p.Δ^2
        # The second-moment term enters through its square root.
        @. p.Δ = (β1 * mt / (1 - β1 * β1p) + (1 - β1) * p.Δ / (1 - β1p)) / (√(vt * β2 / (1 - β2p)) + ϵ) * η
        β1p *= β1
        β2p *= β2
    end
end
2017-10-12 08:31:38 +00:00
# Returns a closure that clamps every entry of `p.Δ` in place to
# `[-thresh, thresh]` (and returns `p.Δ`, as `clamp!` does).
function clip(p::Param, thresh::Real)
    lo, hi = -thresh, thresh
    return () -> clamp!(p.Δ, lo, hi)
end
# Returns a closure that adds `γ * p.x` to `p.Δ` in place.
# When `γ` is zero, the closure is a no-op returning `nothing`.
function expdecay(p::Param, γ::Real)
    iszero(γ) && return () -> nothing
    return () -> (p.Δ .+= γ .* p.x)
end
# Inverse-time decay of `p.Δ`.
# On its k-th call (k = 0, 1, 2, ...) the returned closure scales `p.Δ` in
# place by `1 / (1 + γ * k)`, then returns the incremented call count.
# When `γ` is zero, the closure is a no-op returning `nothing`.
function invdecay(p::Param, γ::Real)
    iszero(γ) && return () -> nothing
    # A Ref keeps the counter's type concrete instead of boxing a reassigned capture.
    calls = Ref(0)
    return function ()
        p.Δ .*= 1 / (1 + γ * calls[])
        calls[] += 1
    end
end