optimisers rework
This commit is contained in:
parent
892a779ed1
commit
387686eb41
@ -2,13 +2,13 @@ __precompile__()
|
|||||||
|
|
||||||
module Flux
|
module Flux
|
||||||
|
|
||||||
|
# Zero Flux Given
|
||||||
|
|
||||||
using Juno
|
using Juno
|
||||||
using Lazy: @forward
|
using Lazy: @forward
|
||||||
|
|
||||||
export Chain, Linear
|
export Chain, Linear
|
||||||
|
|
||||||
# Zero Flux Given
|
|
||||||
|
|
||||||
using NNlib
|
using NNlib
|
||||||
export σ, relu, softmax
|
export σ, relu, softmax
|
||||||
|
|
||||||
|
@ -1,9 +1,11 @@
|
|||||||
module Optimise
|
module Optimise
|
||||||
|
|
||||||
export sgd, update!, params, train!
|
export update!, params, train!,
|
||||||
|
SGD
|
||||||
|
|
||||||
include("params.jl")
|
include("params.jl")
|
||||||
include("optimisers.jl")
|
include("optimisers.jl")
|
||||||
|
include("interface.jl")
|
||||||
include("train.jl")
|
include("train.jl")
|
||||||
|
|
||||||
using Flux.Tracker: TrackedArray
|
using Flux.Tracker: TrackedArray
|
||||||
|
12
src/optimise/interface.jl
Normal file
12
src/optimise/interface.jl
Normal file
@ -0,0 +1,12 @@
|
|||||||
|
call(f, xs...) = f(xs...)
|
||||||
|
|
||||||
|
function optimiser(m, fs...)
|
||||||
|
ps = Param.(params(m))
|
||||||
|
fs = map(ps) do p
|
||||||
|
os = map(f -> f(p), fs)
|
||||||
|
() -> foreach(call, os)
|
||||||
|
end
|
||||||
|
() -> foreach(call, fs)
|
||||||
|
end
|
||||||
|
|
||||||
|
SGD(m, η = 1) = optimiser(m, p -> descent(p, 0.1))
|
@ -1,283 +0,0 @@
|
|||||||
export SGD, AdaGrad, RMSProp, AdaDelta, Adam
|
|
||||||
|
|
||||||
struct Optimizer
|
|
||||||
steps
|
|
||||||
end
|
|
||||||
|
|
||||||
function (o::Optimizer)(ps::Vector{Param})
|
|
||||||
states = map(ps) do p
|
|
||||||
p, map(x->x(p), o.steps)
|
|
||||||
end
|
|
||||||
|
|
||||||
() -> for (p, steps) in states
|
|
||||||
foreach(f->f(p), steps)
|
|
||||||
@. p.x -= p.Δx
|
|
||||||
end
|
|
||||||
end
|
|
||||||
|
|
||||||
function Momentum(η)
|
|
||||||
function (p)
|
|
||||||
momentum = zeros(p.x)
|
|
||||||
|
|
||||||
function (p)
|
|
||||||
@. momentum = η * momentum + p.Δx
|
|
||||||
@. p.Δx = momentum
|
|
||||||
end
|
|
||||||
end
|
|
||||||
end
|
|
||||||
|
|
||||||
function NesterovMomentum(η)
|
|
||||||
function (p)
|
|
||||||
momentum = zeros(p.x)
|
|
||||||
|
|
||||||
function (p)
|
|
||||||
@. momentum = η * momentum + p.Δx
|
|
||||||
@. p.Δx = η * momentum + p.Δx
|
|
||||||
end
|
|
||||||
end
|
|
||||||
end
|
|
||||||
|
|
||||||
function WeightDecayConst(γ)
|
|
||||||
function (p)
|
|
||||||
function (p)
|
|
||||||
# avoid bouncing around 0
|
|
||||||
x = p.x - p.Δx
|
|
||||||
@. p.Δx += (abs(x) <= γ) * x + (abs(x) > γ) * γ * sign(x)
|
|
||||||
end
|
|
||||||
end
|
|
||||||
end
|
|
||||||
|
|
||||||
function WeightDecayRatio(γ)
|
|
||||||
function (p)
|
|
||||||
function (p)
|
|
||||||
@. p.Δx += γ * p.x
|
|
||||||
end
|
|
||||||
end
|
|
||||||
end
|
|
||||||
|
|
||||||
function GradDecayFix(lr)
|
|
||||||
function (p)
|
|
||||||
function (p)
|
|
||||||
@. p.Δx *= lr
|
|
||||||
end
|
|
||||||
end
|
|
||||||
end
|
|
||||||
|
|
||||||
function GradDecayExp(γ)
|
|
||||||
function (p)
|
|
||||||
n_iter = 0
|
|
||||||
|
|
||||||
function (p)
|
|
||||||
p.Δx .*= γ ^ n_iter
|
|
||||||
n_iter += 1
|
|
||||||
end
|
|
||||||
end
|
|
||||||
end
|
|
||||||
|
|
||||||
function GradDecayInv(γ)
|
|
||||||
function (p)
|
|
||||||
n_iter = 0
|
|
||||||
|
|
||||||
function (p)
|
|
||||||
p.Δx .*= 1 / (1 + γ * n_iter)
|
|
||||||
n_iter += 1
|
|
||||||
end
|
|
||||||
end
|
|
||||||
end
|
|
||||||
|
|
||||||
function GradClipConst(threshold)
|
|
||||||
function (p)
|
|
||||||
function (p)
|
|
||||||
p.Δx .= max.(min.(p.Δx, threshold), -threshold)
|
|
||||||
end
|
|
||||||
end
|
|
||||||
end
|
|
||||||
|
|
||||||
function Accumulate(window)
|
|
||||||
function (p)
|
|
||||||
index = 0
|
|
||||||
acc = zeros(p.x)
|
|
||||||
|
|
||||||
function (p)
|
|
||||||
acc .+= p.Δx
|
|
||||||
|
|
||||||
if index >= window
|
|
||||||
p.Δx .= acc
|
|
||||||
acc .= 0
|
|
||||||
index = 0
|
|
||||||
else
|
|
||||||
p.Δx .= 0
|
|
||||||
index += 1
|
|
||||||
end
|
|
||||||
end
|
|
||||||
end
|
|
||||||
end
|
|
||||||
|
|
||||||
function _AdaGrad(ϵ)
|
|
||||||
function (p)
|
|
||||||
acc = zeros(p.x) .+ ϵ
|
|
||||||
|
|
||||||
function (p)
|
|
||||||
@. acc += p.Δx ^ 2
|
|
||||||
@. p.Δx /= √acc
|
|
||||||
end
|
|
||||||
end
|
|
||||||
end
|
|
||||||
|
|
||||||
function _RMSProp(ρ, ϵ)
|
|
||||||
function (p)
|
|
||||||
acc = zeros(p.x) .+ ϵ
|
|
||||||
|
|
||||||
function (p)
|
|
||||||
@. acc = ρ * acc + (1 - ρ) * p.Δx ^ 2
|
|
||||||
@. p.Δx /= √acc
|
|
||||||
end
|
|
||||||
end
|
|
||||||
end
|
|
||||||
|
|
||||||
function _AdaDelta(ρ, ϵ)
|
|
||||||
function (p)
|
|
||||||
acc = zeros(p.x) .+ ϵ
|
|
||||||
Δacc = zeros(p.x) .+ ϵ
|
|
||||||
|
|
||||||
function (p)
|
|
||||||
@. acc = ρ * acc + (1 - ρ) * p.Δx ^ 2
|
|
||||||
@. p.Δx *= √Δacc / √acc
|
|
||||||
@. Δacc = ρ * Δacc + (1 - ρ) * p.Δx ^ 2
|
|
||||||
end
|
|
||||||
end
|
|
||||||
end
|
|
||||||
|
|
||||||
function _Adam(β1, β2, ϵ)
|
|
||||||
function (p)
|
|
||||||
mt = zeros(p.x)
|
|
||||||
vt = zeros(p.x) .+ ϵ
|
|
||||||
β1p = β1
|
|
||||||
β2p = β2
|
|
||||||
|
|
||||||
function (p)
|
|
||||||
@. mt = β1 * mt + (1 - β1) * p.Δx
|
|
||||||
@. vt = β2 * vt + (1 - β2) * p.Δx ^ 2
|
|
||||||
|
|
||||||
@. p.Δx = √(1 - β2p) / √(1 - β1p) * mt / √vt
|
|
||||||
|
|
||||||
β1p *= β1
|
|
||||||
β2p *= β2
|
|
||||||
end
|
|
||||||
end
|
|
||||||
end
|
|
||||||
|
|
||||||
macro restrict_range(var::Symbol, range::String)
|
|
||||||
left, right = split(range, ", ")
|
|
||||||
lo = left[1] == '[' ? :>= : :>
|
|
||||||
lt = left[2:end]
|
|
||||||
ro = right[end] == ']' ? :<= : :<
|
|
||||||
rt = right[1:end-1]
|
|
||||||
|
|
||||||
error_msg = "$var ∈ $range must be hold"
|
|
||||||
var = esc(var)
|
|
||||||
|
|
||||||
quote
|
|
||||||
$( lt != "-∞" && :( $lo($var, $(parse(Float64, lt))) || throw(ArgumentError($error_msg)) ) )
|
|
||||||
$( rt != "∞" && :( $ro($var, $(parse(Float64, rt))) || throw(ArgumentError($error_msg)) ) )
|
|
||||||
end
|
|
||||||
end
|
|
||||||
|
|
||||||
function SGD(; lr::Real=.1,
|
|
||||||
momentum::Real=0,
|
|
||||||
decay::Real=0,
|
|
||||||
nesterov::Bool=false)
|
|
||||||
|
|
||||||
@restrict_range lr "[0, ∞)"
|
|
||||||
@restrict_range momentum "[0, 1]"
|
|
||||||
@restrict_range decay "[0, ∞)"
|
|
||||||
|
|
||||||
steps = []
|
|
||||||
|
|
||||||
if momentum != 0
|
|
||||||
nesterov ? push!(steps, NesterovMomentum(momentum)) :
|
|
||||||
push!(steps, Momentum(momentum))
|
|
||||||
end
|
|
||||||
|
|
||||||
decay != 0 && push!(steps, GradDecayInv(decay))
|
|
||||||
|
|
||||||
lr != 1 && push!(steps, GradDecayFix(lr))
|
|
||||||
|
|
||||||
Optimizer(steps)
|
|
||||||
end
|
|
||||||
|
|
||||||
function AdaGrad(; lr::Real=.001,
|
|
||||||
epsilon::Real=1e-6,
|
|
||||||
decay::Real=0.)
|
|
||||||
|
|
||||||
@restrict_range lr "[0, ∞)"
|
|
||||||
@restrict_range epsilon "(0, ∞)"
|
|
||||||
@restrict_range decay "[0, ∞)"
|
|
||||||
|
|
||||||
steps = Any[_AdaGrad(epsilon)]
|
|
||||||
|
|
||||||
decay != 0 && push!(steps, GradDecayInv(decay))
|
|
||||||
|
|
||||||
lr != 1 && push!(steps, GradDecayFix(lr))
|
|
||||||
|
|
||||||
Optimizer(steps)
|
|
||||||
end
|
|
||||||
|
|
||||||
function RMSProp(; lr::Real=.001,
|
|
||||||
rho::Real=.9,
|
|
||||||
epsilon::Real=1e-6,
|
|
||||||
decay::Real=0.)
|
|
||||||
|
|
||||||
@restrict_range lr "[0, ∞)"
|
|
||||||
@restrict_range rho "[0, 1]"
|
|
||||||
@restrict_range epsilon "(0, ∞)"
|
|
||||||
@restrict_range decay "[0, ∞)"
|
|
||||||
|
|
||||||
steps = Any[_RMSProp(rho, epsilon)]
|
|
||||||
|
|
||||||
decay != 0 && push!(steps, GradDecayInv(decay))
|
|
||||||
|
|
||||||
lr != 1 && push!(steps, GradDecayFix(lr))
|
|
||||||
|
|
||||||
Optimizer(steps)
|
|
||||||
end
|
|
||||||
|
|
||||||
function AdaDelta(; lr::Real=1.,
|
|
||||||
rho::Real=.9,
|
|
||||||
epsilon::Real=1e-6,
|
|
||||||
decay::Real=0.)
|
|
||||||
|
|
||||||
@restrict_range lr "[0, ∞)"
|
|
||||||
@restrict_range rho "[0, 1]"
|
|
||||||
@restrict_range epsilon "(0, ∞)"
|
|
||||||
@restrict_range decay "[0, ∞)"
|
|
||||||
|
|
||||||
steps = Any[_AdaDelta(rho, epsilon)]
|
|
||||||
|
|
||||||
decay != 0 && push!(steps, GradDecayInv(decay))
|
|
||||||
|
|
||||||
lr != 1 && push!(steps, GradDecayFix(lr))
|
|
||||||
|
|
||||||
Optimizer(steps)
|
|
||||||
end
|
|
||||||
|
|
||||||
function Adam(; lr::Real=.1,
|
|
||||||
beta1::Real=.9,
|
|
||||||
beta2::Real=.999,
|
|
||||||
epsilon::Real=1e-6,
|
|
||||||
decay::Real=0.)
|
|
||||||
|
|
||||||
@restrict_range lr "[0, ∞)"
|
|
||||||
@restrict_range beta1 "[0, 1]"
|
|
||||||
@restrict_range beta2 "[0, 1]"
|
|
||||||
@restrict_range epsilon "(0, ∞)"
|
|
||||||
@restrict_range decay "[0, ∞)"
|
|
||||||
|
|
||||||
steps = Any[_Adam(beta1, beta2, epsilon)]
|
|
||||||
|
|
||||||
decay != 0 && push!(steps, GradDecayInv(decay))
|
|
||||||
|
|
||||||
lr != 1 && push!(steps, GradDecayFix(lr))
|
|
||||||
|
|
||||||
Optimizer(steps)
|
|
||||||
end
|
|
@ -1,13 +1,71 @@
|
|||||||
struct SGD
|
function descent(p::Param, η::Real)
|
||||||
ps::Vector{Param}
|
() -> p.x .-= p.Δ .* η
|
||||||
η::Float32
|
|
||||||
end
|
end
|
||||||
|
|
||||||
sgd(m, η) = SGD(params(m), η)
|
function momentum(p::Param, ρ::Real)
|
||||||
|
mo = zeros(p.x)
|
||||||
|
() -> p.Δ .= mo .= ρ .* mo .+ p.Δ
|
||||||
|
end
|
||||||
|
|
||||||
function update!(o::SGD)
|
function nesterov(p::Param, ρ::Real)
|
||||||
for p in o.ps
|
mo = zeros(p.x)
|
||||||
p.x .-= p.Δ .* o.η
|
function ()
|
||||||
Δ .= 0
|
mo .= ρ .* mo .+ p.Δ
|
||||||
|
p.Δ .= ρ .* mo .+ p.Δ
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
function clip(p::Param, thresh::Real)
|
||||||
|
() -> clamp!(p.Δ, -thresh, thresh)
|
||||||
|
end
|
||||||
|
|
||||||
|
function weightdecay(p::Param, γ::Real)
|
||||||
|
() -> p.Δ .+= γ .* p.x
|
||||||
|
end
|
||||||
|
|
||||||
|
function invdecay(p::Param, γ::Real)
|
||||||
|
n = 0
|
||||||
|
function ()
|
||||||
|
p.Δ .*= 1 / (1 + γ * n)
|
||||||
|
n += 1
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
function rmsprop(p::Param; η::Real = 0.001, ρ::Real = 0.9, ϵ::Real = 1e-8)
|
||||||
|
acc = zeros(p.x) .+ ϵ
|
||||||
|
function ()
|
||||||
|
@. acc = ρ * acc + (1 - ρ) * p.Δ ^ 2
|
||||||
|
@. p.Δ /= √acc * η
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
function adagrad(p::Param; η::Real = 0.01, ϵ::Real = 1e-8)
|
||||||
|
acc = zeros(p.x) .+ ϵ
|
||||||
|
function ()
|
||||||
|
@. acc += p.Δ ^ 2
|
||||||
|
@. p.Δ /= √acc * η
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
function adadelta(p::Param; ρ::Real = 0.95, ϵ::Real = 1e-8)
|
||||||
|
acc = zeros(p.x) .+ ϵ
|
||||||
|
Δacc = zeros(p.x) .+ ϵ
|
||||||
|
function ()
|
||||||
|
@. acc = ρ * acc + (1 - ρ) * p.Δ ^ 2
|
||||||
|
@. p.Δ *= √Δacc / √acc
|
||||||
|
@. Δacc = ρ * Δacc + (1 - ρ) * p.Δ ^ 2
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
function adam(p::Param; η::Real = 0.001, β1::Real = 0.9, β2::Real = 0.999, ϵ::Real = 1e-8)
|
||||||
|
mt = zeros(p.x)
|
||||||
|
vt = zeros(p.x) .+ ϵ
|
||||||
|
β1p, β2p = β1, β2
|
||||||
|
function ()
|
||||||
|
@. mt = β1 * mt + (1 - β1) * p.Δ
|
||||||
|
@. vt = β2 * vt + (1 - β2) * p.Δ ^ 2
|
||||||
|
@. p.Δ = √(1 - β2p) / √(1 - β1p) * mt / √vt * η
|
||||||
|
β1p *= β1
|
||||||
|
β2p *= β2
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
@ -7,7 +7,7 @@ function train!(m, data, opt; epoch = 1)
|
|||||||
loss = m(x, y)
|
loss = m(x, y)
|
||||||
@show loss
|
@show loss
|
||||||
back!(loss)
|
back!(loss)
|
||||||
update!(opt)
|
opt()
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
Loading…
Reference in New Issue
Block a user