added remaining optimizers and tests

Dhairya Gandhi 2018-09-16 17:34:51 +05:30
parent 63bc71698b
commit 6665189ff1
5 changed files with 174 additions and 39 deletions

View File

@@ -19,8 +19,9 @@ export Tracker, TrackedArray, TrackedVector, TrackedMatrix, param
include("optimise/Optimise.jl")
using .Optimise
using .Optimise: @epochs
-export Descent, ADAM, Momentum, Nesterov,
-  RMSProp, update!
+export Descent, ADAM, Momentum, Nesterov, RMSProp,
+  ADAGrad, AdaMax, ADADelta, AMSGrad, NADAM,
+  InvDecay, ExpDecay

include("utils.jl")
include("onehot.jl")

View File

@@ -1,7 +1,9 @@
module Optimise

export train!,
-  Descent, ADAM, Momentum, Nesterov, RMSProp, stop, StopException
+  Descent, ADAM, Momentum, Nesterov, RMSProp,
+  ADAGrad, AdaMax, ADADelta, AMSGrad, NADAM,
+  InvDecay, ExpDecay, stop, StopException, Compose

include("optimisers.jl")
include("train.jl")

View File

@@ -20,7 +20,7 @@ function update!(o::Descent, x, Δ)
end

"""
-    Momentum(params, η = 0.01; ρ = 0.9, decay = 0)
+    Momentum(params, η = 0.01; ρ = 0.9)

Gradient descent with learning rate `η` and momentum `ρ`.
"""
@@ -83,7 +83,7 @@ function update!(o::RMSProp, x, Δ)
end

"""
-    ADAM(params, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0)
+    ADAM(η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08)

[ADAM](https://arxiv.org/abs/1412.6980v8) optimiser.
"""
@@ -105,36 +105,154 @@ function update!(o::ADAM, x, Δ)
  return Δ
end

-# """
-#   AdaMax(params, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0)
-#
-# [AdaMax](https://arxiv.org/abs/1412.6980v9) optimiser. Variant of ADAM based on
-# the ∞-norm.
-# """
-
-# """
-#   ADAGrad(params, η = 0.01; ϵ = 1e-8, decay = 0)
-#
-# [ADAGrad](http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf) optimiser.
-# Parameters don't need tuning.
-# """
-
-# """
-#   ADADelta(params; ρ = 0.9, ϵ = 1e-8, decay = 0)
-#
-# [ADADelta](http://arxiv.org/abs/1212.5701) optimiser. Parameters don't need
-# tuning.
-# """
-
-# """
-#   AMSGrad(params; η = 0.001, β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0)
-#
-# [AMSGrad](https://openreview.net/forum?id=ryQu7f-RZ) optimiser. Parameters don't need
-# tuning.
-# """
-
-# struct Optimiser
-#   os::Vector{Any}
-# end
-
+"""
+    AdaMax(params, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08)
+
+[AdaMax](https://arxiv.org/abs/1412.6980v9) optimiser. Variant of ADAM based on
+the ∞-norm.
+"""
+mutable struct AdaMax
+  eta::Float64
+  beta::Tuple{Float64,Float64}
+  state::IdDict
+end
+
+AdaMax(η = 0.001, β = (0.9, 0.999)) = AdaMax(η, β, IdDict())
+
+function update!(o::AdaMax, x, Δ)
+  η, β = o.eta, o.beta
+  mt, ut, βp = get!(o.state, x, (zero(x), zero(x), β))
+  @. mt = β[1] * mt + (1 - β[1]) * Δ
+  @. ut = max(β[2] * ut, abs(Δ))
+  @. Δ = (η/(1 - βp[1])) * mt/(ut + ϵ)
+  o.state[x] = (mt, ut, βp .* β)
+  return Δ
+end
+
+"""
+    ADAGrad(η = 0.1; ϵ = 1e-8)
+
+[ADAGrad](http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf) optimiser.
+Parameters don't need tuning.
+"""
+mutable struct ADAGrad
+  eta::Float64
+  acc::IdDict
+end
+
+ADAGrad(η = 0.1) = ADAGrad(η, IdDict())
+
+function update!(o::ADAGrad, x, Δ)
+  η = o.eta
+  acc = get!(o.acc, x, fill(ϵ, size(x)))::typeof(x)
+  @. acc += Δ^2
+  @. Δ *= η / (√acc + ϵ)
+end
+
+"""
+    ADADelta(params; ρ = 0.9, ϵ = 1e-8)
+
+[ADADelta](http://arxiv.org/abs/1212.5701) optimiser. Parameters don't need
+tuning.
+"""
+mutable struct ADADelta
+  rho::Float64
+  state::IdDict
+end
+
+ADADelta(ρ = 0.9) = ADADelta(ρ, IdDict())
+
+function update!(o::ADADelta, x, Δ)
+  ρ = o.rho
+  acc, Δacc = get!(o.state, x, (zero(x), zero(x)))
+  @. acc = ρ * acc + (1 - ρ) * Δ^2
+  @. Δ *= √(Δacc + ϵ) / √(acc + ϵ)
+  @. Δacc = ρ * Δacc + (1 - ρ) * Δ^2
+  return Δ
+end
+
+"""
+    AMSGrad(η = 0.001, β = (0.9, 0.999))
+
+[AMSGrad](https://openreview.net/forum?id=ryQu7f-RZ) optimiser. Parameters don't need
+tuning.
+"""
+mutable struct AMSGrad
+  eta::Float64
+  beta::Tuple{Float64, Float64}
+  state::IdDict
+end
+
+AMSGrad(η = 0.001, β = (0.9, 0.999)) = AMSGrad(η, β, IdDict())
+
+function update!(o::AMSGrad, x, Δ)
+  η, β = o.eta, o.beta
+  mt, vt, v̂t = get!(o.state, x, (fill(ϵ, size(x)), fill(ϵ, size(x)), fill(ϵ, size(x))))
+  @. mt = β[1] * mt + (1 - β[1]) * Δ
+  @. vt = β[2] * vt + (1 - β[2]) * Δ ^ 2
+  @. v̂t = max.(v̂t, vt)
+  @. Δ = η * mt / √v̂t
+end
+
+"""
+    NADAM(η = 0.001, β = (0.9, 0.999))
+
+[NADAM](http://cs229.stanford.edu/proj2015/054_report.pdf) optimiser. Parameters don't need
+tuning.
+"""
+mutable struct NADAM
+  eta::Float64
+  beta::Tuple{Float64, Float64}
+  state::IdDict
+end
+
+NADAM(η = 0.001, β = (0.9, 0.999)) = NADAM(η, β, IdDict())
+
+function update!(o::NADAM, x, Δ)
+  η, β = o.eta, o.beta
+  β1p, β2p = o.beta
+  mt, vt = get!(o.state, x, (zero(x), zero(x)))
+  @. mt = β[1] * mt + (1 - β[1]) * Δ
+  @. vt = β[2] * vt + (1 - β[2]) * Δ^2
+  @. Δ = (β[1] * mt / (1 - β[1] * β1p) + (1 - β[1]) * Δ / (1 - β1p)) / √(vt * β[2] / (1 - β2p) + ϵ) * η
+  o.state[x] = (mt, vt, (β1p * β[1], β2p * β[2]))
+  return Δ
+end
+
+mutable struct Compose
+  os::Vector{Any}
+end
+
+function update!(o::Compose, x, Δ)
+  for opt in o.os
+    Δ = update!(opt, x, Δ)
+  end
+  return Δ
+end
+
# TODO: decay
+
+mutable struct InvDecay
+  gamma::Float64
+  n::Int64
+end
+
+InvDecay(γ = 0.001, n = 0) = InvDecay(γ, n)
+
+function update!(o::InvDecay, x, Δ)
+  γ, n = o.gamma, o.n
+  Δ .*= 1 / (1 + γ * n)
+  o.n += 1
+  return Δ
+end
+
+mutable struct ExpDecay
+  gamma::Float64
+end
+
+ExpDecay(γ = 0.001) = ExpDecay(γ)
+
+function update!(o::ExpDecay, x, Δ)
+  γ = o.gamma
+  @. Δ += γ * x
+end

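A small usage sketch (not part of the commit): `Compose` threads the gradient through each optimiser in turn, so the new decay rules can be chained in front of a step-size rule, which is what the new `Compose` test below does. The names and numbers here are arbitrary:

using Flux
using Flux.Optimise: Compose, ExpDecay, ADAM, update!

# ExpDecay first adds γ .* x to the gradient (a weight-decay term),
# then ADAM turns the modified gradient into a step.
opt = Compose([ExpDecay(0.001), ADAM(0.001)])

# Used exactly like a single optimiser inside a training loop:
#   Δ = update!(opt, p.data, p.grad);  p.data .-= Δ

Order matters: each `update!` receives the Δ produced by the previous optimiser in the chain.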
View File

@@ -361,7 +361,7 @@
  track(Call(back, tracker.(args)), y)
end

-using Base.Broadcast: BroadcastStyle, ArrayStyle, Broadcasted, broadcasted, cat_nested
+using Base.Broadcast: BroadcastStyle, ArrayStyle, Broadcasted, broadcasted

struct TrackedStyle <: BroadcastStyle end

@@ -385,10 +385,6 @@
using Requires

-Base.Broadcast.cat_nested(t::Base.Broadcast.Broadcasted, rest...) = (cat_nested(t.args...)..., cat_nested(rest...)...)
-Base.Broadcast.cat_nested(t::Any, rest...) = (t, cat_nested(rest...)...)
-Base.Broadcast.cat_nested() = ()
-
# https://github.com/FluxML/Flux.jl/issues/353
@init Requires.isprecompiling() || @eval Base.Broadcast begin
  function flatten(bc::Broadcasted{Style}) where {Style}

View File

@@ -3,13 +3,16 @@ using Flux.Tracker
using Test

@testset "Optimise" begin
  w = randn(10, 10)
-  @testset for Opt in [Descent, ADAM, Nesterov, RMSProp, Momentum]
+  @testset for Opt in [ADAGrad, AdaMax, ADADelta, AMSGrad, NADAM, Descent, ADAM, Nesterov, RMSProp, Momentum]
    w′ = param(randn(10, 10))
    loss(x) = Flux.mse(w*x, w′*x)
    opt = Opt(0.001)
-    if opt isa Descent
+    if opt isa Descent || opt isa ADAGrad
      opt = Opt(0.1)
    end
+    if opt isa ADADelta
+      opt = Opt(0.9)
+    end
    for t = 1: 10^5
      l = loss(rand(10))
      back!(l)
@@ -20,6 +23,21 @@ using Test
  end
end

+@testset "Compose" begin
+  w = randn(10, 10)
+  @testset for Opt in [InvDecay, ExpDecay]
+    w′ = param(randn(10, 10))
+    loss(x) = Flux.mse(w*x, w′*x)
+    opt = Compose(vec([Opt(), ADAM(0.001)]))
+    for t = 1:10^5
+      l = loss(rand(10))
+      back!(l)
+      delta = Optimise.update!(opt, w′.data, w′.grad)
+      w′.data .-= delta
+    end
+    @test Flux.mse(w, w′) < 0.01
+  end
+end
@testset "Training Loop" begin
  i = 0
  l = param(1)