Flux.jl/src/optimise/optimisers.jl

98 lines
2.1 KiB
Julia
Raw Normal View History

2017-09-01 21:06:51 +00:00
function descent(p::Param, η::Real)
2017-09-12 13:11:03 +00:00
function ()
2017-10-12 08:31:38 +00:00
@. p.x -= η * p.Δ
@. p.Δ = 0
2017-09-12 13:11:03 +00:00
end
2017-08-22 21:25:18 +00:00
end
2017-10-12 08:31:38 +00:00
function momentum(p::Param, ρ, η)
v = zeros(p.x)
2017-09-01 21:06:51 +00:00
function ()
2017-10-12 08:31:38 +00:00
@. v = ρ * v - η * p.Δ
@. p.Δ = -v
2017-09-01 21:06:51 +00:00
end
end
2017-10-12 08:31:38 +00:00
# Ref. https://arxiv.org/pdf/1212.0901.pdf
function nesterov(p::Param, ρ, η)
v = zeros(p.x)
2017-09-01 21:06:51 +00:00
function ()
2017-10-12 08:31:38 +00:00
d = @. ρ^2 * v - (1+ρ) * η * p.Δ
@. v = ρ*v - η*p.Δ
@. p.Δ = -d
2017-09-01 21:06:51 +00:00
end
end
function rmsprop(p::Param; η::Real = 0.001, ρ::Real = 0.9, ϵ::Real = 1e-8)
2017-10-12 08:31:38 +00:00
acc = zeros(p.x)
2017-09-01 21:06:51 +00:00
function ()
2017-10-12 08:31:38 +00:00
@. acc = ρ * acc + (1 - ρ) * p.Δ^2
@. p.Δ *= η / (acc + ϵ)
2017-09-01 21:06:51 +00:00
end
end
function adagrad(p::Param; η::Real = 0.01, ϵ::Real = 1e-8)
acc = zeros(p.x) .+ ϵ
function ()
2017-10-12 08:31:38 +00:00
@. acc += p.Δ^2
2017-11-21 14:25:09 +00:00
@. p.Δ *= η / acc
2017-09-01 21:06:51 +00:00
end
end
2017-10-12 08:31:38 +00:00
function adadelta(p::Param; ρ::Real = 0.9, ϵ::Real = 1e-8)
acc = zeros(p.x)
Δacc = zeros(p.x)
2017-09-01 21:06:51 +00:00
function ()
2017-10-12 08:31:38 +00:00
@. acc = ρ * acc + (1 - ρ) * p.Δ^2
@. p.Δ *= (Δacc + ϵ) / (acc + ϵ)
@. Δacc = ρ * Δacc + (1 - ρ) * p.Δ^2
end
2017-09-01 21:06:51 +00:00
end
2017-08-22 21:25:18 +00:00
2017-09-01 21:06:51 +00:00
function adam(p::Param; η::Real = 0.001, β1::Real = 0.9, β2::Real = 0.999, ϵ::Real = 1e-8)
mt = zeros(p.x)
2017-10-12 08:31:38 +00:00
vt = zeros(p.x)
2017-09-01 21:06:51 +00:00
β1p, β2p = β1, β2
function ()
@. mt = β1 * mt + (1 - β1) * p.Δ
2017-10-12 08:31:38 +00:00
@. vt = β2 * vt + (1 - β2) * p.Δ^2
2017-12-08 18:20:53 +00:00
@. p.Δ = mt / (1 - β1p) / ((vt / (1 - β2p)) + ϵ) * η
2017-09-01 21:06:51 +00:00
β1p *= β1
β2p *= β2
2017-08-22 21:25:18 +00:00
end
end
2017-12-04 08:17:05 +00:00
function amsgrad(p::Param; η::Real = 0.001, β1::Real = 0.9, β2::Real = 0.999, ϵ::Real = 1e-8)
mt = zeros(p.x)
vt = zeros(p.x) .+ ϵ
v̂t = zeros(p.x) .+ ϵ
function ()
@. mt = β1 * mt + (1 - β1) * p.Δ
@. vt = β2 * vt + (1 - β2) * p.Δ ^ 2
@. v̂t = max.(v̂t, vt)
@. p.Δ = η * mt / v̂t
end
end
2017-12-08 18:20:53 +00:00
2017-10-12 08:31:38 +00:00
clip(p::Param, thresh::Real) = () -> clamp!(p.Δ, -thresh, thresh)
function expdecay(p::Param, γ::Real)
if γ != 0
return () -> p.Δ .+= γ .* p.x
else
return () -> nothing
end
end
function invdecay(p::Param, γ::Real)
if γ != 0
n = 0
return () -> begin
p.Δ .*= 1 / (1 + γ * n)
n += 1
end
else
return () -> nothing
end
2017-12-08 18:20:53 +00:00
end