added remaining optimizers and tests
This commit is contained in:
parent
63bc71698b
commit
6665189ff1
|
@ -19,8 +19,9 @@ export Tracker, TrackedArray, TrackedVector, TrackedMatrix, param
|
|||
include("optimise/Optimise.jl")
|
||||
using .Optimise
|
||||
using .Optimise: @epochs
|
||||
export Descent, ADAM, Momentum, Nesterov,
|
||||
RMSProp, update!
|
||||
export Descent, ADAM, Momentum, Nesterov, RMSProp,
|
||||
ADAGrad, AdaMax, ADADelta, AMSGrad, NADAM,
|
||||
InvDecay, ExpDecay
|
||||
|
||||
include("utils.jl")
|
||||
include("onehot.jl")
|
||||
|
|
|
@ -1,7 +1,9 @@
|
|||
module Optimise
|
||||
|
||||
export train!,
|
||||
Descent, ADAM, Momentum, Nesterov, RMSProp, stop, StopException
|
||||
Descent, ADAM, Momentum, Nesterov, RMSProp,
|
||||
ADAGrad, AdaMax, ADADelta, AMSGrad, NADAM,
|
||||
InvDecay, ExpDecay, stop, StopException, Compose
|
||||
|
||||
include("optimisers.jl")
|
||||
include("train.jl")
|
||||
|
|
|
@ -20,7 +20,7 @@ function update!(o::Descent, x, Δ)
|
|||
end
|
||||
|
||||
"""
|
||||
Momentum(params, η = 0.01; ρ = 0.9, decay = 0)
|
||||
Momentum(params, η = 0.01; ρ = 0.9)
|
||||
|
||||
Gradient descent with learning rate `η` and momentum `ρ`.
|
||||
"""
|
||||
|
@ -83,7 +83,7 @@ function update!(o::RMSProp, x, Δ)
|
|||
end
|
||||
|
||||
"""
|
||||
ADAM(params, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0)
|
||||
ADAM(η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08)
|
||||
|
||||
[ADAM](https://arxiv.org/abs/1412.6980v8) optimiser.
|
||||
"""
|
||||
|
@ -105,36 +105,154 @@ function update!(o::ADAM, x, Δ)
|
|||
return Δ
|
||||
end
|
||||
|
||||
# """
|
||||
# AdaMax(params, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0)
|
||||
#
|
||||
# [AdaMax](https://arxiv.org/abs/1412.6980v9) optimiser. Variant of ADAM based on
|
||||
# the ∞-norm.
|
||||
# """
|
||||
"""
|
||||
AdaMax(params, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08)
|
||||
|
||||
# """
|
||||
# ADAGrad(params, η = 0.01; ϵ = 1e-8, decay = 0)
|
||||
#
|
||||
# [ADAGrad](http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf) optimiser.
|
||||
# Parameters don't need tuning.
|
||||
# """
|
||||
[AdaMax](https://arxiv.org/abs/1412.6980v9) optimiser. Variant of ADAM based on
|
||||
the ∞-norm.
|
||||
"""
|
||||
mutable struct AdaMax
|
||||
eta::Float64
|
||||
beta::Tuple{Float64,Float64}
|
||||
state::IdDict
|
||||
end
|
||||
|
||||
# """
|
||||
# ADADelta(params; ρ = 0.9, ϵ = 1e-8, decay = 0)
|
||||
#
|
||||
# [ADADelta](http://arxiv.org/abs/1212.5701) optimiser. Parameters don't need
|
||||
# tuning.
|
||||
# """
|
||||
AdaMax(η = 0.001, β = (0.9, 0.999)) = AdaMax(η, β, IdDict())
|
||||
|
||||
# """
|
||||
# AMSGrad(params; η = 0.001, β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0)
|
||||
#
|
||||
# [AMSGrad](https://openreview.net/forum?id=ryQu7f-RZ) optimiser. Parameters don't need
|
||||
# tuning.
|
||||
# """
|
||||
function update!(o::AdaMax, x, Δ)
|
||||
η, β = o.eta, o.beta
|
||||
mt, ut, βp = get!(o.state, x, (zero(x), zero(x), β))
|
||||
@. mt = β[1] * mt + (1 - β[1]) * Δ
|
||||
@. ut = max(β[2] * ut, abs(Δ))
|
||||
@. Δ = (η/(1 - βp[1])) * mt/(ut + ϵ)
|
||||
o.state[x] = (mt, ut, βp .* β)
|
||||
return Δ
|
||||
end
|
||||
|
||||
# struct Optimiser
|
||||
# os::Vector{Any}
|
||||
# end
|
||||
"""
|
||||
ADAGrad(η = 0.1; ϵ = 1e-8)
|
||||
|
||||
[ADAGrad](http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf) optimiser.
|
||||
Parameters don't need tuning.
|
||||
"""
|
||||
mutable struct ADAGrad
|
||||
eta::Float64
|
||||
acc::IdDict
|
||||
end
|
||||
|
||||
ADAGrad(η = 0.1) = ADAGrad(η, IdDict())
|
||||
|
||||
function update!(o::ADAGrad, x, Δ)
|
||||
η = o.eta
|
||||
acc = get!(o.acc, x, fill(ϵ, size(x)))::typeof(x)
|
||||
@. acc += Δ^2
|
||||
@. Δ *= η / √(acc + ϵ)
|
||||
end
|
||||
|
||||
"""
|
||||
ADADelta(params; ρ = 0.9, ϵ = 1e-8)
|
||||
|
||||
[ADADelta](http://arxiv.org/abs/1212.5701) optimiser. Parameters don't need
|
||||
tuning.
|
||||
"""
|
||||
mutable struct ADADelta
|
||||
rho::Float64
|
||||
state::IdDict
|
||||
end
|
||||
|
||||
ADADelta(ρ = 0.9) = ADADelta(ρ, IdDict())
|
||||
|
||||
function update!(o::ADADelta, x, Δ)
|
||||
ρ = o.rho
|
||||
acc, Δacc = get!(o.state, x, (zero(x), zero(x)))
|
||||
@. acc = ρ * acc + (1 - ρ) * Δ^2
|
||||
@. Δ *= √(Δacc + ϵ) / √(acc + ϵ)
|
||||
@. Δacc = ρ * Δacc + (1 - ρ) * Δ^2
|
||||
return Δ
|
||||
end
|
||||
|
||||
"""
|
||||
AMSGrad(η = 0.001, β = (0.9, 0.999))
|
||||
|
||||
[AMSGrad](https://openreview.net/forum?id=ryQu7f-RZ) optimiser. Parameters don't need
|
||||
tuning.
|
||||
"""
|
||||
mutable struct AMSGrad
|
||||
eta::Float64
|
||||
beta::Tuple{Float64, Float64}
|
||||
state::IdDict
|
||||
end
|
||||
|
||||
AMSGrad(η = 0.001, β = (0.9, 0.999)) = AMSGrad(η, β, IdDict())
|
||||
|
||||
function update!(o::AMSGrad, x, Δ)
|
||||
η, β = o.eta, o.beta
|
||||
mt, vt, v̂t = get!(o.state, x, (fill(ϵ, size(x)), fill(ϵ, size(x)), fill(ϵ, size(x))))
|
||||
@. mt = β[1] * mt + (1 - β[1]) * Δ
|
||||
@. vt = β[2] * vt + (1 - β[2]) * Δ ^ 2
|
||||
@. v̂t = max.(v̂t, vt)
|
||||
@. Δ = η * mt / √v̂t
|
||||
end
|
||||
|
||||
"""
|
||||
NADAM(η = 0.001, β = (0.9, 0.999))
|
||||
|
||||
[NADAM](http://cs229.stanford.edu/proj2015/054_report.pdf) optimiser. Parameters don't need
|
||||
tuning.
|
||||
"""
|
||||
mutable struct NADAM
|
||||
eta::Float64
|
||||
beta::Tuple{Float64, Float64}
|
||||
state::IdDict
|
||||
end
|
||||
|
||||
NADAM(η = 0.001, β = (0.9, 0.999)) = NADAM(η, β, IdDict())
|
||||
|
||||
function update!(o::NADAM, x, Δ)
|
||||
η, β = o.eta, o.beta
|
||||
β1p, β2p = o.beta
|
||||
mt, vt = get!(o.state, x, (zero(x), zero(x)))
|
||||
@. mt = β[1] * mt + (1 - β[1]) * Δ
|
||||
@. vt = β[2] * vt + (1 - β[2]) * Δ^2
|
||||
@. Δ = (β[1] * mt / (1 - β[1] * β1p) + (1 - β[1]) * Δ / (1 - β1p)) / √(vt * β[2] / (1 - β2p) + ϵ) * η
|
||||
o.state[x] = (mt, vt, (β1p * β[1], β2p * β[2]))
|
||||
return Δ
|
||||
end
|
||||
|
||||
mutable struct Compose
|
||||
os::Vector{Any}
|
||||
end
|
||||
|
||||
function update!(o::Compose, x, Δ)
|
||||
for opt in o.os
|
||||
Δ = update!(opt, x, Δ)
|
||||
end
|
||||
return Δ
|
||||
end
|
||||
|
||||
# TODO: decay
|
||||
|
||||
mutable struct InvDecay
|
||||
gamma::Float64
|
||||
n::Int64
|
||||
end
|
||||
|
||||
InvDecay(γ = 0.001, n = 0) = InvDecay(γ, n)
|
||||
|
||||
function update!(o::InvDecay, x, Δ)
|
||||
γ, n = o.gamma, o.n
|
||||
Δ .*= 1 / (1 + γ * n)
|
||||
o.n += 1
|
||||
return Δ
|
||||
end
|
||||
|
||||
mutable struct ExpDecay
|
||||
gamma::Float64
|
||||
end
|
||||
|
||||
ExpDecay(γ = 0.001) = ExpDecay(γ)
|
||||
|
||||
function update!(o::ExpDecay, x, Δ)
|
||||
γ = o.gamma
|
||||
@. Δ += γ * x
|
||||
end
|
||||
|
|
|
@ -361,7 +361,7 @@ end
|
|||
track(Call(back, tracker.(args)), y)
|
||||
end
|
||||
|
||||
using Base.Broadcast: BroadcastStyle, ArrayStyle, Broadcasted, broadcasted, cat_nested
|
||||
using Base.Broadcast: BroadcastStyle, ArrayStyle, Broadcasted, broadcasted
|
||||
|
||||
struct TrackedStyle <: BroadcastStyle end
|
||||
|
||||
|
@ -385,10 +385,6 @@ end
|
|||
|
||||
using Requires
|
||||
|
||||
Base.Broadcast.cat_nested(t::Base.Broadcast.Broadcasted, rest...) = (cat_nested(t.args...)..., cat_nested(rest...)...)
|
||||
Base.Broadcast.cat_nested(t::Any, rest...) = (t, cat_nested(rest...)...)
|
||||
Base.Broadcast.cat_nested() = ()
|
||||
|
||||
# https://github.com/FluxML/Flux.jl/issues/353
|
||||
@init Requires.isprecompiling() || @eval Base.Broadcast begin
|
||||
function flatten(bc::Broadcasted{Style}) where {Style}
|
||||
|
|
|
@ -3,13 +3,16 @@ using Flux.Tracker
|
|||
using Test
|
||||
@testset "Optimise" begin
|
||||
w = randn(10, 10)
|
||||
@testset for Opt in [Descent, ADAM, Nesterov, RMSProp, Momentum]
|
||||
@testset for Opt in [ADAGrad, AdaMax, ADADelta, AMSGrad, NADAM, Descent, ADAM, Nesterov, RMSProp, Momentum]
|
||||
w′ = param(randn(10, 10))
|
||||
loss(x) = Flux.mse(w*x, w′*x)
|
||||
opt = Opt(0.001)
|
||||
if opt isa Descent
|
||||
if opt isa Descent || opt isa ADAGrad
|
||||
opt = Opt(0.1)
|
||||
end
|
||||
if opt isa ADADelta
|
||||
opt = Opt(0.9)
|
||||
end
|
||||
for t = 1: 10^5
|
||||
l = loss(rand(10))
|
||||
back!(l)
|
||||
|
@ -20,6 +23,21 @@ using Test
|
|||
end
|
||||
end
|
||||
|
||||
@testset "Compose" begin
|
||||
w = randn(10, 10)
|
||||
@testset for Opt in [InvDecay, ExpDecay]
|
||||
w′ = param(randn(10, 10))
|
||||
loss(x) = Flux.mse(w*x, w′*x)
|
||||
opt = Compose(vec([Opt(), ADAM(0.001)]))
|
||||
for t = 1:10^5
|
||||
l = loss(rand(10))
|
||||
back!(l)
|
||||
delta = Optimise.update!(opt, w′.data, w′.grad)
|
||||
w′.data .-= delta
|
||||
end
|
||||
@test Flux.mse(w, w′) < 0.01
|
||||
end
|
||||
|
||||
@testset "Training Loop" begin
|
||||
i = 0
|
||||
l = param(1)
|
||||
|
|
Loading…
Reference in New Issue