Compare commits

...

1 Commits

Author SHA1 Message Date
Mike J Innes
f1ec0c6f72 initial sketch 2018-05-31 20:29:59 +01:00
6 changed files with 132 additions and 196 deletions

View File

@ -22,7 +22,7 @@ export Tracker, TrackedArray, TrackedVector, TrackedMatrix, param
include("optimise/Optimise.jl") include("optimise/Optimise.jl")
using .Optimise using .Optimise
using .Optimise: @epochs using .Optimise: @epochs
export SGD, ADAM, AdaMax, Momentum, Nesterov, export SGD, Descent, ADAM, AdaMax, Momentum, Nesterov,
RMSProp, ADAGrad, ADADelta, AMSGrad RMSProp, ADAGrad, ADADelta, AMSGrad
include("utils.jl") include("utils.jl")

View File

@ -1,21 +1,9 @@
module Optimise module Optimise
export train!, export train!,
SGD, ADAM, AdaMax, Momentum, Nesterov, RMSProp, ADAGrad, ADADelta, AMSGrad SGD, Descent, ADAM, AdaMax, Momentum, Nesterov, RMSProp, ADAGrad, ADADelta, AMSGrad
struct Param{T}
x::T
Δ::T
end
Base.convert(::Type{Param}, x::AbstractArray) = Param(x, zeros(x))
include("optimisers.jl") include("optimisers.jl")
include("interface.jl")
include("train.jl") include("train.jl")
using Flux.Tracker: TrackedArray
Base.convert(::Type{Param}, x::TrackedArray) = Param(x.data, x.grad)
end end

View File

@ -1,93 +0,0 @@
call(f, xs...) = f(xs...)
# note for optimisers: set to zero
# p.Δ at the end of the weigths update
function optimiser(ps, fs...)
ps = [Param(p) for p in ps]
fs = map(ps) do p
os = map(f -> f(p), fs)
() -> foreach(call, os)
end
() -> foreach(call, fs)
end
"""
SGD(params, η = 0.1; decay = 0)
Classic gradient descent optimiser with learning rate `η`.
For each parameter `p` and its gradient `δp`, this runs `p -= η*δp`.
Supports inverse decaying learning rate if the `decay` argument is provided.
"""
SGD(ps, η = 0.1; decay = 0) =
optimiser(ps, p -> invdecay(p, decay), p -> descent(p,η))
"""
Momentum(params, η = 0.01; ρ = 0.9, decay = 0)
SGD with learning rate `η`, momentum `ρ` and optional learning rate inverse decay.
"""
Momentum(ps, η = 0.01; ρ = 0.9, decay = 0) =
optimiser(ps, p->invdecay(p,decay), p->momentum(p, ρ, η), p->descent(p,1))
"""
Nesterov(params, η = 0.01; ρ = 0.9, decay = 0)
SGD with learning rate `η`, Nesterov momentum `ρ` and optional learning rate inverse decay.
"""
Nesterov(ps, η = 0.01; ρ = 0.9, decay = 0) =
optimiser(ps, p->invdecay(p,decay), p->nesterov(p, ρ, η), p->descent(p,1))
"""
RMSProp(params, η = 0.001; ρ = 0.9, ϵ = 1e-8, decay = 0)
[RMSProp](http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf)
optimiser. Parameters other than learning rate don't need tuning. Often a good
choice for recurrent networks.
"""
RMSProp(ps, η = 0.001; ρ = 0.9, ϵ = 1e-8, decay = 0) =
optimiser(ps, p->rmsprop(p; η=η, ρ=ρ, ϵ=ϵ), p->invdecay(p,decay), p->descent(p,1))
"""
ADAM(params, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0)
[ADAM](https://arxiv.org/abs/1412.6980v8) optimiser.
"""
ADAM(ps, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0) =
optimiser(ps, p->adam(p; η=η, β1=β1, β2=β2, ϵ=ϵ), p->invdecay(p,decay), p->descent(p,1))
"""
AdaMax(params, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0)
[AdaMax](https://arxiv.org/abs/1412.6980v9) optimiser. Variant of ADAM based on
the -norm.
"""
AdaMax(ps, η = 0.002; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0) =
optimiser(ps, p->adamax(p; η=η, β1=β1, β2=β2, ϵ=ϵ), p->invdecay(p,decay), p->descent(p,1))
"""
ADAGrad(params, η = 0.01; ϵ = 1e-8, decay = 0)
[ADAGrad](http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf) optimiser.
Parameters don't need tuning.
"""
ADAGrad(ps, η = 0.01; ϵ = 1e-8, decay = 0) =
optimiser(ps, p->adagrad(p; η=η, ϵ=ϵ), p->invdecay(p,decay), p->descent(p,1))
"""
ADADelta(params; ρ = 0.9, ϵ = 1e-8, decay = 0)
[ADADelta](http://arxiv.org/abs/1212.5701) optimiser. Parameters don't need
tuning.
"""
ADADelta(ps; ρ = 0.9, ϵ = 1e-8, decay = 0) =
optimiser(ps, p->adadelta(p; ρ=ρ, ϵ=ϵ), p->descent(p,1))
"""
AMSGrad(params; η = 0.001, β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0)
[AMSGrad](https://openreview.net/forum?id=ryQu7f-RZ) optimiser. Parameters don't need
tuning.
"""
AMSGrad(ps, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0) =
optimiser(ps, p -> amsgrad(p; η = η, β1 = β1, β2 = β2, ϵ = ϵ), p -> invdecay(p, decay), p -> descent(p, 1))

View File

@ -1,109 +1,139 @@
function descent(p::Param, η::Real) using Flux
function () using Base: @get!
@. p.x -= η * p.Δ
@. p.Δ = 0 const ϵ = 1e-8
end
# TODO: should use weak refs
"""
Descent(η)
Classic gradient descent optimiser with learning rate `η`.
For each parameter `p` and its gradient `δp`, this runs `p -= η*δp`.
"""
mutable struct Descent
eta::Float64
end end
function momentum(p::Param, ρ, η) function update!(o::Descent, x, Δ)
v = zeros(p.x) Δ .*= o.eta
function ()
@. v = ρ * v - η * p.Δ
@. p.Δ = -v
end
end end
# Ref. https://arxiv.org/pdf/1212.0901.pdf """
function nesterov(p::Param, ρ, η) Momentum(params, η = 0.01; ρ = 0.9, decay = 0)
v = zeros(p.x)
function () Gradient descent with learning rate `η` and momentum `ρ`.
d = @. ρ^2 * v - (1+ρ) * η * p.Δ """
@. v = ρ*v - η*p.Δ mutable struct Momentum
@. p.Δ = -d eta::Float64
end rho::Float64
velocity::ObjectIdDict
end end
function rmsprop(p::Param; η::Real = 0.001, ρ::Real = 0.9, ϵ::Real = 1e-8) Momentum(η, ρ = 0.9) = Momentum(η, ρ, ObjectIdDict())
acc = zeros(p.x)
function () function update!(o::Momentum, x, Δ)
@. acc = ρ * acc + (1 - ρ) * p.Δ^2 η, ρ = o.eta, o.rho
@. p.Δ *= η / (acc + ϵ) v = @get!(o.velocity, x, zero(x))::typeof(x)
end @. v = ρ * v - η * Δ
@. Δ = -v
end end
function adagrad(p::Param; η::Real = 0.01, ϵ::Real = 1e-8) """
acc = zeros(p.x) .+ ϵ Nesterov(eta, ρ = 0.9)
function ()
@. acc += p.Δ^2 Gradient descent with learning rate `η` and Nesterov momentum `ρ`.
@. p.Δ *= η / acc """
end mutable struct Nesterov
eta::Float64
rho::Float64
velocity::ObjectIdDict
end end
function adadelta(p::Param; ρ::Real = 0.9, ϵ::Real = 1e-8) Nesterov(η, ρ = 0.9) = Nesterov(η, ρ, ObjectIdDict())
acc = zeros(p.x)
Δacc = zeros(p.x) function update!(o::Nesterov, x, Δ)
function () η, ρ = o.eta, o.rho
@. acc = ρ * acc + (1 - ρ) * p.Δ^2 v = @get!(o.velocity, x, zero(x))::typeof(x)
@. p.Δ *= (Δacc + ϵ) / (acc + ϵ) d = @. ρ^2 * v - (1+ρ) * η * Δ
@. Δacc = ρ * Δacc + (1 - ρ) * p.Δ^2 @. v = ρ*v - η*Δ
end @. Δ = -d
end end
function adam(p::Param; η::Real = 0.001, β1::Real = 0.9, β2::Real = 0.999, ϵ::Real = 1e-8) """
mt = zeros(p.x) RMSProp(η = 0.001, ρ = 0.9)
vt = zeros(p.x)
β1p, β2p = β1, β2 [RMSProp](http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf)
function () optimiser. Parameters other than learning rate don't need tuning. Often a good
@. mt = β1 * mt + (1 - β1) * p.Δ choice for recurrent networks.
@. vt = β2 * vt + (1 - β2) * p.Δ^2 """
@. p.Δ = mt / (1 - β1p) / ((vt / (1 - β2p)) + ϵ) * η mutable struct RMSProp
β1p *= β1 eta::Float64
β2p *= β2 rho::Float64
end acc::ObjectIdDict
end end
function adamax(p::Param; η::Real = 0.002, β1::Real = 0.9, β2::Real = 0.999, ϵ::Real = 1e-8) RMSProp(η = 0.001, ρ = 0.9) = RMSProp(η, ρ, ObjectIdDict())
mt = zeros(p.x)
ut = zeros(p.x) function update!(o::RMSProp, x, Δ)
β1p = β1 η, ρ = o.eta, o.rho
function () acc = @get!(o.acc, x, zero(x))::typeof(x)
@. mt = β1 * mt + (1 - β1) * p.Δ @. acc = ρ * acc + (1 - ρ) * Δ^2
@. ut = max(β2 * ut, abs(p.Δ)) @. Δ *= η / (acc + ϵ)
@. p.Δ = (η/(1 - β1p)) * mt/(ut + ϵ)
β1p *= β1
end
end end
function amsgrad(p::Param; η::Real = 0.001, β1::Real = 0.9, β2::Real = 0.999, ϵ::Real = 1e-8) """
mt = zeros(p.x) ADAM(params, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0)
vt = zeros(p.x) .+ ϵ
v̂t = zeros(p.x) .+ ϵ [ADAM](https://arxiv.org/abs/1412.6980v8) optimiser.
function () """
@. mt = β1 * mt + (1 - β1) * p.Δ mutable struct ADAM
@. vt = β2 * vt + (1 - β2) * p.Δ ^ 2 eta::Float64
@. v̂t = max.(v̂t, vt) beta::Tuple{Float64,Float64}
@. p.Δ = η * mt / v̂t state::ObjectIdDict
end
end end
clip(p::Param, thresh::Real) = () -> clamp!(p.Δ, -thresh, thresh) ADAM(η = 0.001, β = (0.9, 0.999)) = ADAM(η, β, ObjectIdDict())
function expdecay(p::Param, γ::Real) function update!(o::ADAM, x, Δ)
if γ != 0 η, β = o.eta, o.beta
return () -> p.Δ .+= γ .* p.x mt, vt, βp = @get!(o.state, x, (zero(x), zero(x), β))
else @. mt = β[1] * mt + (1 - β[1]) * Δ
return () -> nothing @. vt = β[2] * vt + (1 - β[2]) * Δ^2
end @. Δ = mt / (1 - βp[1]) / ((vt / (1 - βp[2])) + ϵ) * η
o.state[x] = (mt, vt, βp .* β)
end end
function invdecay(p::Param, γ::Real) # """
if γ != 0 # AdaMax(params, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0)
n = 0 #
return () -> begin # [AdaMax](https://arxiv.org/abs/1412.6980v9) optimiser. Variant of ADAM based on
p.Δ .*= 1 / (1 + γ * n) # the ∞-norm.
n += 1 # """
end
else # """
return () -> nothing # ADAGrad(params, η = 0.01; ϵ = 1e-8, decay = 0)
end #
end # [ADAGrad](http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf) optimiser.
# Parameters don't need tuning.
# """
# """
# ADADelta(params; ρ = 0.9, ϵ = 1e-8, decay = 0)
#
# [ADADelta](http://arxiv.org/abs/1212.5701) optimiser. Parameters don't need
# tuning.
# """
# """
# AMSGrad(params; η = 0.001, β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0)
#
# [AMSGrad](https://openreview.net/forum?id=ryQu7f-RZ) optimiser. Parameters don't need
# tuning.
# """
# struct Optimiser
# os::Vector{Any}
# end
# TODO: decay

View File

@ -1,6 +1,16 @@
using Juno using Juno
using Flux.Tracker: back! using Flux.Tracker: data, grad, back!
function update!(opt, xs)
for x in xs
x, Δ = data(x), grad(x)
update!(opt, x, Δ)
x .-= Δ
Δ .= 0
end
end
# Callback niceties
runall(f) = f runall(f) = f
runall(fs::AbstractVector) = () -> foreach(call, fs) runall(fs::AbstractVector) = () -> foreach(call, fs)

View File

@ -10,6 +10,7 @@ istracked(x) = tracker(x) ≠ nothing
isleaf(x) = !istracked(x) || isleaf(tracker(x)) isleaf(x) = !istracked(x) || isleaf(tracker(x))
data(x) = istracked(x) ? data(tracker(x)) : x data(x) = istracked(x) ? data(tracker(x)) : x
grad(x) = grad(tracker(x)) grad(x) = grad(tracker(x))
grad(::Void) = nothing
struct Call{F,As<:Tuple} struct Call{F,As<:Tuple}
func::F func::F