2017-09-01 21:06:51 +00:00
|
|
|
|
function descent(p::Param, η::Real)
|
2017-09-12 13:11:03 +00:00
|
|
|
|
function ()
|
2017-10-12 08:31:38 +00:00
|
|
|
|
@. p.x -= η * p.Δ
|
|
|
|
|
@. p.Δ = 0
|
2017-09-12 13:11:03 +00:00
|
|
|
|
end
|
2017-08-22 21:25:18 +00:00
|
|
|
|
end
|
|
|
|
|
|
2017-10-12 08:31:38 +00:00
|
|
|
|
function momentum(p::Param, ρ, η)
|
|
|
|
|
v = zeros(p.x)
|
2017-09-01 21:06:51 +00:00
|
|
|
|
function ()
|
2017-10-12 08:31:38 +00:00
|
|
|
|
@. v = ρ * v - η * p.Δ
|
|
|
|
|
@. p.Δ = -v
|
2017-09-01 21:06:51 +00:00
|
|
|
|
end
|
|
|
|
|
end
|
|
|
|
|
|
2017-10-12 08:31:38 +00:00
|
|
|
|
# Ref. https://arxiv.org/pdf/1212.0901.pdf
|
|
|
|
|
function nesterov(p::Param, ρ, η)
|
|
|
|
|
v = zeros(p.x)
|
2017-09-01 21:06:51 +00:00
|
|
|
|
function ()
|
2017-10-12 08:31:38 +00:00
|
|
|
|
d = @. ρ^2 * v - (1+ρ) * η * p.Δ
|
|
|
|
|
@. v = ρ*v - η*p.Δ
|
|
|
|
|
@. p.Δ = -d
|
2017-09-01 21:06:51 +00:00
|
|
|
|
end
|
|
|
|
|
end
|
|
|
|
|
|
|
|
|
|
function rmsprop(p::Param; η::Real = 0.001, ρ::Real = 0.9, ϵ::Real = 1e-8)
|
2017-10-12 08:31:38 +00:00
|
|
|
|
acc = zeros(p.x)
|
2017-09-01 21:06:51 +00:00
|
|
|
|
function ()
|
2017-10-12 08:31:38 +00:00
|
|
|
|
@. acc = ρ * acc + (1 - ρ) * p.Δ^2
|
|
|
|
|
@. p.Δ *= η / (√acc + ϵ)
|
2017-09-01 21:06:51 +00:00
|
|
|
|
end
|
|
|
|
|
end
|
|
|
|
|
|
|
|
|
|
function adagrad(p::Param; η::Real = 0.01, ϵ::Real = 1e-8)
|
|
|
|
|
acc = zeros(p.x) .+ ϵ
|
|
|
|
|
function ()
|
2017-10-12 08:31:38 +00:00
|
|
|
|
@. acc += p.Δ^2
|
2017-11-21 14:25:09 +00:00
|
|
|
|
@. p.Δ *= η / √acc
|
2017-09-01 21:06:51 +00:00
|
|
|
|
end
|
|
|
|
|
end
|
|
|
|
|
|
2017-10-12 08:31:38 +00:00
|
|
|
|
function adadelta(p::Param; ρ::Real = 0.9, ϵ::Real = 1e-8)
|
|
|
|
|
acc = zeros(p.x)
|
|
|
|
|
Δacc = zeros(p.x)
|
2017-09-01 21:06:51 +00:00
|
|
|
|
function ()
|
2017-10-12 08:31:38 +00:00
|
|
|
|
@. acc = ρ * acc + (1 - ρ) * p.Δ^2
|
|
|
|
|
@. p.Δ *= √(Δacc + ϵ) / √(acc + ϵ)
|
|
|
|
|
@. Δacc = ρ * Δacc + (1 - ρ) * p.Δ^2
|
|
|
|
|
end
|
2017-09-01 21:06:51 +00:00
|
|
|
|
end
|
2017-08-22 21:25:18 +00:00
|
|
|
|
|
2017-09-01 21:06:51 +00:00
|
|
|
|
function adam(p::Param; η::Real = 0.001, β1::Real = 0.9, β2::Real = 0.999, ϵ::Real = 1e-8)
|
|
|
|
|
mt = zeros(p.x)
|
2017-10-12 08:31:38 +00:00
|
|
|
|
vt = zeros(p.x)
|
2017-09-01 21:06:51 +00:00
|
|
|
|
β1p, β2p = β1, β2
|
|
|
|
|
function ()
|
|
|
|
|
@. mt = β1 * mt + (1 - β1) * p.Δ
|
2017-10-12 08:31:38 +00:00
|
|
|
|
@. vt = β2 * vt + (1 - β2) * p.Δ^2
|
2017-12-08 18:20:53 +00:00
|
|
|
|
@. p.Δ = mt / (1 - β1p) / (√(vt / (1 - β2p)) + ϵ) * η
|
2017-09-01 21:06:51 +00:00
|
|
|
|
β1p *= β1
|
|
|
|
|
β2p *= β2
|
2017-08-22 21:25:18 +00:00
|
|
|
|
end
|
|
|
|
|
end
|
2017-12-04 08:17:05 +00:00
|
|
|
|
|
|
|
|
|
function amsgrad(p::Param; η::Real = 0.001, β1::Real = 0.9, β2::Real = 0.999, ϵ::Real = 1e-8)
|
|
|
|
|
mt = zeros(p.x)
|
|
|
|
|
vt = zeros(p.x) .+ ϵ
|
|
|
|
|
v̂t = zeros(p.x) .+ ϵ
|
|
|
|
|
function ()
|
|
|
|
|
@. mt = β1 * mt + (1 - β1) * p.Δ
|
|
|
|
|
@. vt = β2 * vt + (1 - β2) * p.Δ ^ 2
|
|
|
|
|
@. v̂t = max.(v̂t, vt)
|
|
|
|
|
@. p.Δ = η * mt / √v̂t
|
|
|
|
|
end
|
|
|
|
|
end
|
2017-12-08 18:20:53 +00:00
|
|
|
|
|
2017-10-12 08:31:38 +00:00
|
|
|
|
clip(p::Param, thresh::Real) = () -> clamp!(p.Δ, -thresh, thresh)
|
|
|
|
|
|
|
|
|
|
function expdecay(p::Param, γ::Real)
|
|
|
|
|
if γ != 0
|
|
|
|
|
return () -> p.Δ .+= γ .* p.x
|
|
|
|
|
else
|
|
|
|
|
return () -> nothing
|
|
|
|
|
end
|
|
|
|
|
end
|
|
|
|
|
|
|
|
|
|
function invdecay(p::Param, γ::Real)
|
|
|
|
|
if γ != 0
|
|
|
|
|
n = 0
|
|
|
|
|
return () -> begin
|
|
|
|
|
p.Δ .*= 1 / (1 + γ * n)
|
|
|
|
|
n += 1
|
|
|
|
|
end
|
|
|
|
|
else
|
|
|
|
|
return () -> nothing
|
|
|
|
|
end
|
2017-12-08 18:20:53 +00:00
|
|
|
|
end
|