Merge pull request #379 from dhairyagandhi96/master

New optimisers interface
Mike J Innes 2018-10-31 16:38:40 +00:00 committed by GitHub
commit 43c5f90d93
12 changed files with 490 additions and 265 deletions

View File

@@ -4,7 +4,6 @@ os:
   - linux
 # - osx
 julia:
-  - 0.7
   - 1.0
   - nightly
 # uncomment the following lines to override the default test script

View File

@@ -1,4 +1,4 @@
-julia 0.7
+julia 1.0
 Juno
 MacroTools 0.3.3
 NNlib

View File

@@ -19,8 +19,9 @@ export Tracker, TrackedArray, TrackedVector, TrackedMatrix, param
 include("optimise/Optimise.jl")
 using .Optimise
 using .Optimise: @epochs
-export SGD, ADAM, ADAMW, AdaMax, Momentum, Nesterov,
-  RMSProp, ADAGrad, ADADelta, AMSGrad, NADAM
+export SGD, Descent, ADAM, Momentum, Nesterov, RMSProp,
+  ADAGrad, AdaMax, ADADelta, AMSGrad, NADAM,
+  ADAMW, InvDecay, ExpDecay, WeightDecay

 include("utils.jl")
 include("onehot.jl")

View File

@@ -1,23 +1,12 @@
 module Optimise

 export train!,
-  SGD, ADAM, ADAMW, AdaMax, Momentum, Nesterov,
-  RMSProp, ADAGrad, ADADelta, AMSGrad, NADAM, stop, StopException
-
-struct Param{T}
-  x::T
-  Δ::T
-end
-
-Param(x::AbstractArray) = Param(x, zero(x))
+  SGD, Descent, ADAM, Momentum, Nesterov, RMSProp,
+  ADAGrad, AdaMax, ADADelta, AMSGrad, NADAM, ADAMW,
+  InvDecay, ExpDecay, WeightDecay, stop, Optimiser

 include("optimisers.jl")
-include("interface.jl")
 include("train.jl")
+include("deprecations.jl")
-
-using Flux.Tracker: TrackedArray
-
-Param(x::TrackedArray) = Param(x.data, x.grad)
-
-# Base.convert(::Type{Param}, x::TrackedArray) = Param(x.data, x.grad)

 end
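A minimal before/after sketch of the API change this module implements (the model, loss and data below are placeholders, not part of the PR):

```julia
using Flux

m = Dense(10, 2)                      # placeholder model
loss(x, y) = Flux.mse(m(x), y)
data = [(rand(10), rand(2))]

# Old interface (now routed through deprecations.jl):
#   opt = SGD(params(m), 0.1)         # the optimiser captured the parameters
#   Flux.train!(loss, data, opt)

# New interface exported above: the optimiser holds only hyperparameters and
# state, and the parameters are passed to train! explicitly.
opt = Descent(0.1)
Flux.train!(loss, params(m), data, opt)
```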

View File

@@ -0,0 +1,125 @@
using Base: depwarn
check_decay(opt, decay) = decay == 0 ? opt : Optimiser(opt, InvDecay(decay))
# legacy update rule
updaterule(opt, ps) = () -> update!(opt, ps)
function SGD(params::AbstractArray, η = 0.1; decay = 0.)
depwarn("SGD(params) is deprecated; use Descent(η::Float64) instead", :SGD)
ps = params
opt = Descent(η)
opt = check_decay(opt, decay)
updaterule(opt, ps)
end
function Momentum(params::AbstractArray, η = 0.01; ρ = 0.9, decay = 0.)
depwarn("Momentum(params) is deprecated; use Momentum(η::Float64) instead", :Momentum)
ps = params
opt = Momentum(η, ρ)
opt = check_decay(opt, decay)
updaterule(opt, ps)
end
function Nesterov(params::AbstractArray, η = 0.001; ρ = 0.9, decay = 0.)
depwarn("Nesterov(params) is deprecated; use Nesterov(η::Float64) instead", :Nesterov)
ps = params
opt = Nesterov(η, ρ)
opt = check_decay(opt, decay)
updaterule(opt, ps)
end
function RMSProp(params::AbstractArray, η = 0.001; ρ = 0.9, decay = 0.)
depwarn("RMSProp(params) is deprecated; use RMSProp(η::Float64) instead", :RMSProp)
ps = params
opt = RMSProp(η, ρ)
opt = check_decay(opt, decay)
updaterule(opt, ps)
end
function ADAM(params::AbstractArray, η = 0.001; β1 = 0.9, β2 = 0.999, decay = 0.)
depwarn("ADAM(params) is deprecated; use ADAM(η::Float64) instead", :ADAM)
ps = params
β = (β1, β2)
opt = ADAM(η, β)
opt = check_decay(opt, decay)
updaterule(opt, ps)
end
function ADAGrad(params::AbstractArray, η::Float64 = 0.1; decay = 0.)
depwarn("ADAGrad(params) is deprecated; use ADAGrad(η::Float64) instead", :ADAGrad)
ps = params
opt = ADAGrad(η)
opt = check_decay(opt, decay)
updaterule(opt, ps)
end
function ADADelta(params::AbstractArray, ρ::Float64 = 0.9; decay = 0.)
depwarn("ADADelta(params) is deprecated; use ADADelta(η::Float64) instead", :ADADelta)
ps = params
opt = ADADelta(ρ)
opt = check_decay(opt, decay)
updaterule(opt, ps)
end
function AdaMax(params::AbstractArray, η = 0.001; β1 = 0.9, β2 = 0.999, decay = 0.)
depwarn("AdaMax(params) is deprecated; use AdaMax(η::Float64) instead", :AdaMax)
ps = params
β = (β1, β2)
opt = AdaMax(η, β)
opt = check_decay(opt, decay)
updaterule(opt, ps)
end
function AMSGrad(params::AbstractArray, η = 0.001; β1 = 0.9, β2 = 0.999, decay = 0.)
depwarn("AMSGrad(params) is deprecated; use AMSGrad(η::Float64) instead", :AMSGrad)
ps = params
β = (β1, β2)
opt = AMSGrad(η, β)
opt = check_decay(opt, decay)
updaterule(opt, ps)
end
function NADAM(params::AbstractArray, η = 0.001; β1 = 0.9, β2 = 0.999, decay = 0.)
depwarn("NADAM(params) is deprecated; use NADAM(η::Float64) instead", :NADAM)
ps = params
β = (β1, β2)
opt = NADAM(η, β)
opt = check_decay(opt, decay)
updaterule(opt, ps)
end
function ADAMW(params::AbstractArray, η = 0.001; β1 = 0.9, β2 = 0.999, decay = 0.)
depwarn("ADAMW(params) is deprecated; use ADAMW(η::Float64) instead", :ADAMW)
ps = params
β = (β1, β2)
opt = ADAMW(η, β)
opt = check_decay(opt, decay)
decay != 0 && (opt = Optimiser(opt, WeightDecay(decay)))
updaterule(opt, ps)
end
# Old training loop
struct OldOptimiser
func
end
update!(opt::OldOptimiser, ps) = opt.func()
# Train function
function train!(loss, data, opt; cb = () -> ())
depwarn("train!(loss, data, opt) is deprecated; use train!(loss, params, data, opt) instead", :train!)
train!(loss, (), data, OldOptimiser(opt); cb = cb)
end
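A hedged sketch of what these shims produce: the deprecated call builds a new-style optimiser (wrapped in `InvDecay` when `decay` is given) and returns a zero-argument closure, so legacy code that calls `opt()` keeps working. The parameter `W` and the loss below are placeholders:

```julia
using Flux
using Flux.Tracker: param, back!

W = param(randn(2, 2))                 # placeholder tracked parameter
opt = SGD([W], 0.1; decay = 0.01)      # prints a depwarn, returns a closure
# Internally: check_decay wraps Descent(0.1) as Optimiser(Descent(0.1), InvDecay(0.01)),
# and updaterule closes over [W] so the old calling convention still applies updates.
back!(sum(W * ones(2)))
opt()                                  # one in-place update of W, then grads are zeroed
```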

View File

@@ -1,110 +0,0 @@
call(f, xs...) = f(xs...)
# note for optimisers: set to zero
# p.Δ at the end of the weights update
function optimiser(ps, fs...)
ps = [Param(p) for p in ps]
fs = map(ps) do p
os = map(f -> f(p), fs)
() -> foreach(call, os)
end
() -> foreach(call, fs)
end
"""
SGD(params, η = 0.1; decay = 0)
Classic gradient descent optimiser with learning rate `η`.
For each parameter `p` and its gradient `δp`, this runs `p -= η*δp`.
Supports inverse decaying learning rate if the `decay` argument is provided.
"""
SGD(ps, η = 0.1; decay = 0) =
optimiser(ps, p -> invdecay(p, decay), p -> descent(p,η))
"""
Momentum(params, η = 0.01; ρ = 0.9, decay = 0)
SGD with learning rate `η`, momentum `ρ` and optional learning rate inverse decay.
"""
Momentum(ps, η = 0.01; ρ = 0.9, decay = 0) =
optimiser(ps, p->invdecay(p,decay), p->momentum(p, ρ, η), p->descent(p,1))
"""
Nesterov(params, η = 0.01; ρ = 0.9, decay = 0)
SGD with learning rate `η`, Nesterov momentum `ρ` and optional learning rate inverse decay.
"""
Nesterov(ps, η = 0.01; ρ = 0.9, decay = 0) =
optimiser(ps, p->invdecay(p,decay), p->nesterov(p, ρ, η), p->descent(p,1))
"""
RMSProp(params, η = 0.001; ρ = 0.9, ϵ = 1e-8, decay = 0)
[RMSProp](http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf)
optimiser. Parameters other than learning rate don't need tuning. Often a good
choice for recurrent networks.
"""
RMSProp(ps, η = 0.001; ρ = 0.9, ϵ = 1e-8, decay = 0) =
optimiser(ps, p->rmsprop(p; η=η, ρ=ρ, ϵ=ϵ), p->invdecay(p,decay), p->descent(p,1))
"""
ADAM(params, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0)
[ADAM](https://arxiv.org/abs/1412.6980v8) optimiser.
"""
ADAM(ps, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0) =
optimiser(ps, p->adam(p; η=η, β1=β1, β2=β2, ϵ=ϵ), p->invdecay(p,decay), p->descent(p,1))
"""
ADAMW(params, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0)
[ADAMW](https://arxiv.org/abs/1711.05101) fixing weight decay regularization in Adam.
"""
ADAMW(ps, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0) =
optimiser(ps, p->adam(p; η=η, β1=β1, β2=β2, ϵ=ϵ), p->descentweightdecay(p,1,decay))
"""
AdaMax(params, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0)
[AdaMax](https://arxiv.org/abs/1412.6980v9) optimiser. Variant of ADAM based on
the ∞-norm.
"""
AdaMax(ps, η = 0.002; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0) =
optimiser(ps, p->adamax(p; η=η, β1=β1, β2=β2, ϵ=ϵ), p->invdecay(p,decay), p->descent(p,1))
"""
ADAGrad(params, η = 0.01; ϵ = 1e-8, decay = 0)
[ADAGrad](http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf) optimiser.
Parameters don't need tuning.
"""
ADAGrad(ps, η = 0.01; ϵ = 1e-8, decay = 0) =
optimiser(ps, p->adagrad(p; η=η, ϵ=ϵ), p->invdecay(p,decay), p->descent(p,1))
"""
ADADelta(params; ρ = 0.9, ϵ = 1e-8, decay = 0)
[ADADelta](http://arxiv.org/abs/1212.5701) optimiser. Parameters don't need
tuning.
"""
ADADelta(ps; ρ = 0.9, ϵ = 1e-8, decay = 0) =
optimiser(ps, p->adadelta(p; ρ=ρ, ϵ=ϵ), p->descent(p,1))
"""
AMSGrad(params; η = 0.001, β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0)
[AMSGrad](https://openreview.net/forum?id=ryQu7f-RZ) optimiser. Parameters don't need
tuning.
"""
AMSGrad(ps, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0) =
optimiser(ps, p -> amsgrad(p; η = η, β1 = β1, β2 = β2, ϵ = ϵ), p -> invdecay(p, decay), p -> descent(p, 1))
"""
NADAM(params, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0)
[NADAM](https://openreview.net/pdf?id=OM0jvwB8jIp57ZJjtNEZ) optimiser. Parameters other
than learning rate don't need tuning.
"""
NADAM(ps, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0) =
optimiser(ps, p->nadam(p; η=η, β1=β1, β2=β2, ϵ=ϵ), p->invdecay(p,decay), p->descent(p,1))
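For contrast with the new interface, the file removed above composed per-parameter closures: each rule closed over one `Param` and its state, and `optimiser(ps, fs...)` chained them. A standalone illustration of that pattern (plain Julia with a hypothetical `FakeParam` type, not Flux code):

```julia
# Hypothetical stand-in for the removed Param struct.
mutable struct FakeParam
  x::Vector{Float64}
  Δ::Vector{Float64}
end

# One update rule per parameter, returned as a zero-argument closure.
descent_rule(p::FakeParam, η) = () -> (p.x .-= η .* p.Δ; p.Δ .= 0)

ps = [FakeParam(randn(3), ones(3)), FakeParam(randn(3), ones(3))]
fs = [descent_rule(p, 0.1) for p in ps]
step!() = foreach(f -> f(), fs)   # analogous to the closure optimiser(ps, ...) returned
step!()
```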

View File

@@ -1,130 +1,304 @@
-function descent(p::Param, η::Real)
-  function ()
-    @. p.x -= η * p.Δ
-    @. p.Δ = 0
-  end
-end
-
-# Ref: https://arxiv.org/abs/1711.05101.pdf
-function descentweightdecay(p::Param, η::Real, γ::Real)
-  function ()
-    @. p.x = p.x - η * (p.Δ + γ * p.x)
-    @. p.Δ = 0
-  end
-end
-
-function momentum(p::Param, ρ, η)
-  v = zero(p.x)
-  function ()
-    @. v = ρ * v - η * p.Δ
-    @. p.Δ = -v
-  end
-end
-
-# Ref. https://arxiv.org/pdf/1212.0901.pdf
-function nesterov(p::Param, ρ, η)
-  v = zero(p.x)
-  function ()
-    d = @. ρ^2 * v - (1+ρ) * η * p.Δ
-    @. v = ρ*v - η*p.Δ
-    @. p.Δ = -d
-  end
-end
-
-function rmsprop(p::Param; η::Real = 0.001, ρ::Real = 0.9, ϵ::Real = 1e-8)
-  acc = zero(p.x)
-  function ()
-    @. acc = ρ * acc + (1 - ρ) * p.Δ^2
-    @. p.Δ *= η / (acc + ϵ)
-  end
-end
-
-function adagrad(p::Param; η::Real = 0.01, ϵ::Real = 1e-8)
-  acc = zero(p.x) .+ ϵ
-  function ()
-    @. acc += p.Δ^2
-    @. p.Δ *= η / (acc + ϵ)
-  end
-end
-
-function adadelta(p::Param; ρ::Real = 0.9, ϵ::Real = 1e-8)
-  acc = zero(p.x)
-  Δacc = zero(p.x)
-  function ()
-    @. acc = ρ * acc + (1 - ρ) * p.Δ^2
-    @. p.Δ *= (Δacc + ϵ) / (acc + ϵ)
-    @. Δacc = ρ * Δacc + (1 - ρ) * p.Δ^2
-  end
-end
-
-function adam(p::Param; η::Real = 0.001, β1::Real = 0.9, β2::Real = 0.999, ϵ::Real = 1e-8)
-  mt = zero(p.x)
-  vt = zero(p.x)
-  β1p, β2p = β1, β2
-  function ()
-    @. mt = β1 * mt + (1 - β1) * p.Δ
-    @. vt = β2 * vt + (1 - β2) * p.Δ^2
-    @. p.Δ = mt / (1 - β1p) / (vt / (1 - β2p) + ϵ) * η
-    β1p *= β1
-    β2p *= β2
-  end
-end
-
-function adamax(p::Param; η::Real = 0.002, β1::Real = 0.9, β2::Real = 0.999, ϵ::Real = 1e-8)
-  mt = zero(p.x)
-  ut = zero(p.x)
-  β1p = β1
-  function ()
-    @. mt = β1 * mt + (1 - β1) * p.Δ
-    @. ut = max(β2 * ut, abs(p.Δ))
-    @. p.Δ = (η/(1 - β1p)) * mt/(ut + ϵ)
-    β1p *= β1
-  end
-end
-
-function amsgrad(p::Param; η::Real = 0.001, β1::Real = 0.9, β2::Real = 0.999, ϵ::Real = 1e-8)
-  mt = zero(p.x)
-  vt = zero(p.x) .+ ϵ
-  v̂t = zero(p.x) .+ ϵ
-  function ()
-    @. mt = β1 * mt + (1 - β1) * p.Δ
-    @. vt = β2 * vt + (1 - β2) * p.Δ ^ 2
-    @. v̂t = max.(v̂t, vt)
-    @. p.Δ = η * mt / v̂t
-  end
-end
-
-function nadam(p::Param; η::Real = 0.001, β1::Real = 0.9, β2::Real = 0.999, ϵ::Real = 1e-8)
-  mt = zero(p.x)
-  vt = zero(p.x)
-  β1p, β2p = β1, β2
-  function ()
-    @. mt = β1 * mt + (1 - β1) * p.Δ
-    @. vt = β2 * vt + (1 - β2) * p.Δ^2
-    @. p.Δ = (β1 * mt / (1 - β1 * β1p) + (1 - β1) * p.Δ / (1 - β1p)) / (vt * β2 / (1 - β2p) + ϵ) * η
-    β1p *= β1
-    β2p *= β2
-  end
-end
-
-clip(p::Param, thresh::Real) = () -> clamp!(p.Δ, -thresh, thresh)
-
-function expdecay(p::Param, γ::Real)
-  if γ != 0
-    return () -> p.Δ .+= γ .* p.x
-  else
-    return () -> nothing
-  end
-end
-
-function invdecay(p::Param, γ::Real)
-  if γ != 0
-    n = 0
-    return () -> begin
-      p.Δ .*= 1 / (1 + γ * n)
-      n += 1
-    end
-  else
-    return () -> nothing
-  end
-end
+using Flux
+using Base: @get!
+using MacroTools: @forward
+
+const ϵ = 1e-8
+
+# TODO: should use weak refs
+
+"""
+    Descent(η)
+
+Classic gradient descent optimiser with learning rate `η`.
+For each parameter `p` and its gradient `δp`, this runs `p -= η*δp`.
+"""
+mutable struct Descent
+  eta::Float64
+end
+
+Descent() = Descent(0.1)
+
+function update!(o::Descent, x, Δ)
+  Δ .*= o.eta
+end
+
+"""
+    Momentum(params, η = 0.01; ρ = 0.9)
+
+Gradient descent with learning rate `η` and momentum `ρ`.
+"""
+mutable struct Momentum
+  eta::Float64
+  rho::Float64
+  velocity::IdDict
+end
+
+Momentum(η = 0.01, ρ = 0.9) = Momentum(η, ρ, IdDict())
+
+function update!(o::Momentum, x, Δ)
+  η, ρ = o.eta, o.rho
+  v = get!(o.velocity, x, zero(x))::typeof(x)
+  @. v = ρ * v - η * Δ
+  @. Δ = -v
+end
+
+"""
+    Nesterov(eta, ρ = 0.9)
+
+Gradient descent with learning rate `η` and Nesterov momentum `ρ`.
+"""
+mutable struct Nesterov
+  eta::Float64
+  rho::Float64
+  velocity::IdDict
+end
+
+Nesterov(η = 0.001, ρ = 0.9) = Nesterov(η, ρ, IdDict())
+
+function update!(o::Nesterov, x, Δ)
+  η, ρ = o.eta, o.rho
+  v = get!(o.velocity, x, zero(x))::typeof(x)
+  d = @. ρ^2 * v - (1+ρ) * η * Δ
+  @. v = ρ*v - η*Δ
+  @. Δ = -d
+end
+
+"""
+    RMSProp(η = 0.001, ρ = 0.9)
+
+[RMSProp](http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf)
+optimiser. Parameters other than learning rate don't need tuning. Often a good
+choice for recurrent networks.
+"""
+mutable struct RMSProp
+  eta::Float64
+  rho::Float64
+  acc::IdDict
+end
+
+RMSProp(η = 0.001, ρ = 0.9) = RMSProp(η, ρ, IdDict())
+
+function update!(o::RMSProp, x, Δ)
+  η, ρ = o.eta, o.rho
+  acc = get!(o.acc, x, zero(x))::typeof(x)
+  @. acc = ρ * acc + (1 - ρ) * Δ^2
+  @. Δ *= η / (acc + ϵ)
+end
+
+"""
+    ADAM(η = 0.001, β = (0.9, 0.999))
+
+[ADAM](https://arxiv.org/abs/1412.6980v8) optimiser.
+"""
+mutable struct ADAM
+  eta::Float64
+  beta::Tuple{Float64,Float64}
+  state::IdDict
+end
+
+ADAM(η = 0.001, β = (0.9, 0.999)) = ADAM(η, β, IdDict())
+
+function update!(o::ADAM, x, Δ)
+  η, β = o.eta, o.beta
+  mt, vt, βp = get!(o.state, x, (zero(x), zero(x), β))
+  @. mt = β[1] * mt + (1 - β[1]) * Δ
+  @. vt = β[2] * vt + (1 - β[2]) * Δ^2
+  @. Δ = mt / (1 - βp[1]) / ((vt / (1 - βp[2])) + ϵ) * η
+  o.state[x] = (mt, vt, βp .* β)
+  return Δ
+end
+
+"""
+    AdaMax(params, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08)
+
+[AdaMax](https://arxiv.org/abs/1412.6980v9) optimiser. Variant of ADAM based on
+the ∞-norm.
+"""
+mutable struct AdaMax
+  eta::Float64
+  beta::Tuple{Float64,Float64}
+  state::IdDict
+end
+
+AdaMax(η = 0.001, β = (0.9, 0.999)) = AdaMax(η, β, IdDict())
+
+function update!(o::AdaMax, x, Δ)
+  η, β = o.eta, o.beta
+  mt, ut, βp = get!(o.state, x, (zero(x), zero(x), β))
+  @. mt = β[1] * mt + (1 - β[1]) * Δ
+  @. ut = max(β[2] * ut, abs(Δ))
+  @. Δ = (η/(1 - βp[1])) * mt/(ut + ϵ)
+  o.state[x] = (mt, ut, βp .* β)
+  return Δ
+end
+
+"""
+    ADAGrad(η = 0.1; ϵ = 1e-8)
+
+[ADAGrad](http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf) optimiser.
+Parameters don't need tuning.
+"""
+mutable struct ADAGrad
+  eta::Float64
+  acc::IdDict
+end
+
+ADAGrad(η = 0.1) = ADAGrad(η, IdDict())
+
+function update!(o::ADAGrad, x, Δ)
+  η = o.eta
+  acc = get!(o.acc, x, fill(ϵ, size(x)))::typeof(x)
+  @. acc += Δ^2
+  @. Δ *= η / (acc + ϵ)
+end
+
+"""
+    ADADelta(ρ = 0.9, ϵ = 1e-8)
+
+[ADADelta](http://arxiv.org/abs/1212.5701) optimiser. Parameters don't need
+tuning.
+"""
+mutable struct ADADelta
+  rho::Float64
+  state::IdDict
+end
+
+ADADelta(ρ = 0.9) = ADADelta(ρ, IdDict())
+
+function update!(o::ADADelta, x, Δ)
+  ρ = o.rho
+  acc, Δacc = get!(o.state, x, (zero(x), zero(x)))
+  @. acc = ρ * acc + (1 - ρ) * Δ^2
+  @. Δ *= (Δacc + ϵ) / (acc + ϵ)
+  @. Δacc = ρ * Δacc + (1 - ρ) * Δ^2
+  return Δ
+end
+
+"""
+    AMSGrad(η = 0.001, β = (0.9, 0.999))
+
+[AMSGrad](https://openreview.net/forum?id=ryQu7f-RZ) optimiser. Parameters don't need
+tuning.
+"""
+mutable struct AMSGrad
+  eta::Float64
+  beta::Tuple{Float64, Float64}
+  state::IdDict
+end
+
+AMSGrad(η = 0.001, β = (0.9, 0.999)) = AMSGrad(η, β, IdDict())
+
+function update!(o::AMSGrad, x, Δ)
+  η, β = o.eta, o.beta
+  mt, vt, v̂t = get!(o.state, x, (fill(ϵ, size(x)), fill(ϵ, size(x)), fill(ϵ, size(x))))
+  @. mt = β[1] * mt + (1 - β[1]) * Δ
+  @. vt = β[2] * vt + (1 - β[2]) * Δ ^ 2
+  @. v̂t = max.(v̂t, vt)
+  @. Δ = η * mt / v̂t
+end
+
+"""
+    NADAM(η = 0.001, β = (0.9, 0.999))
+
+[NADAM](http://cs229.stanford.edu/proj2015/054_report.pdf) optimiser. Parameters don't need
+tuning.
+"""
+mutable struct NADAM
+  eta::Float64
+  beta::Tuple{Float64, Float64}
+  state::IdDict
+end
+
+NADAM(η = 0.001, β = (0.9, 0.999)) = NADAM(η, β, IdDict())
+
+function update!(o::NADAM, x, Δ)
+  η, β = o.eta, o.beta
+  β1p, β2p = o.beta
+  mt, vt = get!(o.state, x, (zero(x), zero(x)))
+  @. mt = β[1] * mt + (1 - β[1]) * Δ
+  @. vt = β[2] * vt + (1 - β[2]) * Δ^2
+  @. Δ = (β[1] * mt / (1 - β[1] * β1p) + (1 - β[1]) * Δ / (1 - β1p)) / (vt * β[2] / (1 - β2p) + ϵ) * η
+  o.state[x] = (mt, vt, (β1p * β[1], β2p * β[2]))
+  return Δ
+end
+
+"""
+    ADAMW(η = 0.001, β = (0.9, 0.999), decay = 0)
+
+[ADAMW](https://arxiv.org/abs/1711.05101) fixing weight decay regularization in Adam.
+"""
+ADAMW(η = 0.001, β = (0.9, 0.999), decay = 0) =
+  Optimiser(ADAM(η, β), WeightDecay(decay))
+
+# Compose optimizers
+
+"""
+    Optimiser(a, b, c...)
+
+Combine several optimisers into one; each optimiser produces a modified gradient
+that will be fed into the next, and this is finally applied to the parameter as
+usual.
+"""
+mutable struct Optimiser
+  os::Vector{Any}
+end
+
+Optimiser(o...) = Optimiser(Any[o...])
+
+@forward Optimiser.os Base.getindex, Base.first, Base.last, Base.lastindex, Base.push!, Base.setindex!
+@forward Optimiser.os Base.iterate
+
+Base.getindex(c::Optimiser, i::AbstractArray) = Optimiser(c.os[i]...)
+
+function update!(o::Optimiser, x, Δ)
+  for opt in o.os
+    Δ = update!(opt, x, Δ)
+  end
+  return Δ
+end
+
+mutable struct InvDecay
+  gamma::Float64
+  state::IdDict
+end
+
+InvDecay(γ = 0.001) = InvDecay(γ, IdDict())
+
+function update!(o::InvDecay, x, Δ)
+  γ = o.gamma
+  n = get!(o.state, x, 1)
+  Δ .*= 1 / (1 + γ * n)
+  o.state[x] = n + 1
+  return Δ
+end
+
+mutable struct ExpDecay
+  eta::Float64
+  decay::Float64
+  step::Int64
+  clip::Float64
+  current::IdDict
+end
+
+ExpDecay(opt = 0.001, decay = 0.1, decay_step = 1000, clip = 1e-4) = ExpDecay(opt, decay, decay_step, clip, IdDict())
+
+function update!(o::ExpDecay, x, Δ)
+  η, s, decay = o.eta, o.step, o.decay
+  n = o.current[x] = get(o.current, x, 0) + 1
+  if o.current[x] % s == 0 && count(x -> x % s == 0, values(o.current)) == 1
+    η = max(η * decay^(s / n), o.clip)
+    o.eta = η
+  end
+  @. Δ *= decay
+end
+
+mutable struct WeightDecay
+  wd::Real
+end
+
+WeightDecay() = WeightDecay(0)
+
+function update!(o::WeightDecay, x, Δ)
+  wd = o.wd
+  @. Δ += wd * x
+end
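A usage sketch of the stateful rules and the `Optimiser` composition defined above; `x` and its gradient are placeholder arrays:

```julia
using Flux.Optimise: ADAM, ExpDecay, Optimiser, update!

x = randn(5)                      # placeholder parameter array
Δ = randn(5)                      # placeholder gradient for x

opt = ADAM(0.001)                 # per-parameter state lives in the optimiser's IdDict
x .-= update!(opt, x, Δ)          # update! returns the rescaled step for x

# Rules compose left to right: each one transforms the gradient before the next.
sched = Optimiser(ExpDecay(), ADAM(0.001))
x .-= update!(sched, x, randn(5))
```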

View File

@@ -1,7 +1,16 @@
 using Juno
-using Flux.Tracker: back!
+using Flux.Tracker: data, grad, back!
 import Base.depwarn

+function update!(opt, xs)
+  for x in xs
+    Δ = update!(opt, x.data, x.grad)
+    x.data .-= Δ
+    Δ .= 0
+  end
+end
+
+# Callback niceties
 runall(f) = f
 runall(fs::AbstractVector) = () -> foreach(call, fs)
@@ -35,7 +44,7 @@ function stop()
 end

 """
-    train!(loss, data, opt)
+    train!(model, loss, data, opt)

 For each datapoint `d` in `data` computes the gradient of `loss(d...)` through
 backpropagation and calls the optimizer `opt`.
@@ -44,7 +53,7 @@ Takes a callback as keyword argument `cb`. For example, this will print "training"
 every 10 seconds:

 ```julia
-Flux.train!(loss, data, opt,
+Flux.train!(model, loss, data, opt,
             cb = throttle(() -> println("training"), 10))
 ```
@@ -52,14 +61,14 @@ The callback can return `:stop` to interrupt the training loop.

 Multiple optimisers and callbacks can be passed to `opt` and `cb` as arrays.
 """
-function train!(loss, data, opt; cb = () -> ())
+function train!(loss, ps, data, opt; cb = () -> ())
   cb = runall(cb)
   opt = runall(opt)
   @progress for d in data
     try
       l = loss(d...)
       @interrupts back!(l)
-      opt()
+      update!(opt, ps)
       if cb() == :stop
         depwarn("Use of `:stop` is deprecated; use `Flux.stop()` instead", :stop)
         break
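A sketch of the new entry point: `train!` now takes the parameter collection explicitly and applies `update!(opt, ps)` after each backward pass (the model and data here are placeholders):

```julia
using Flux
using Flux: throttle, params

m = Dense(4, 1)                                    # placeholder model
loss(x, y) = Flux.mse(m(x), y)
data = Iterators.repeated((rand(4), rand(1)), 100)
opt = ADAM(0.001)

Flux.train!(loss, params(m), data, opt,
            cb = throttle(() -> println("training..."), 5))
```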

View File

@@ -69,15 +69,28 @@ end
 # Out-of-place gradients

 struct Params
-  params::IdSet
-  Params(xs) = new(IdSet(xs))
+  order::Vector{Any}
+  params::IdSet{Any}
+  Params() = new([], IdSet())
 end

-@forward Params.params Base.iterate, Base.length
+@forward Params.order Base.iterate, Base.length
+
+function Base.push!(ps::Params, x)
+  if !(x in ps.params)
+    push!(ps.order, x)
+    push!(ps.params, x)
+  end
+  return ps
+end
+
+Base.push!(ps::Params, x...) = (foreach(x -> push!(ps, x), x); ps)
+
+Params(xs) = push!(Params(), xs...)

 function Base.show(io::IO, ps::Params)
   print(io, "Params([")
-  join(io, ps.params, ", ")
+  join(io, ps.order, ", ")
   print(io, "])")
 end
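The effect of the change above, sketched: `Params` now records insertion order in `order` while still using the `IdSet` to drop duplicates (`a` and `b` are placeholder tracked arrays):

```julia
using Flux.Tracker: Params, param

a, b = param(rand(2)), param(rand(2))

ps = Params()
push!(ps, a, b, a)        # the second `a` is ignored by the IdSet membership check
length(ps)                # 2
first(ps) === a           # iteration follows insertion order
```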

View File

@@ -7,6 +7,7 @@ Base.eltype(::IdSet{T}) where T = T

 IdSet() = IdSet{Any}()

+Base.push!(s::IdSet) = s
 Base.push!(s::IdSet{T}, x::T) where T = (s.dict[x] = nothing; s)
 Base.delete!(s::IdSet{T}, x::T) where T = (delete!(s.dict, x); s)
 Base.in(x, s::IdSet) = haskey(s.dict, x)

View File

@@ -40,7 +40,7 @@ function prefor(f, x; seen = IdSet())
 end

 function params(m)
-  ps = []
+  ps = Params()
   prefor(p ->
     Tracker.istracked(p) && Tracker.isleaf(p) &&
     !any(p′ -> p′ === p, ps) && push!(ps, p),
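With `ps = Params()` above, `params(m)` now returns an ordered, deduplicated `Tracker.Params` collection rather than a plain vector; a quick check with a placeholder model:

```julia
using Flux

m = Chain(Dense(3, 2), Dense(2, 1))   # placeholder model
ps = params(m)                        # Tracker.Params, not a Vector
length(ps)                            # 4: two weight matrices and two bias vectors
```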

View File

@@ -3,14 +3,37 @@ using Flux.Tracker
 using Test

 @testset "Optimise" begin
   w = randn(10, 10)
-  @testset for Opt in [SGD, Nesterov, Momentum, ADAM, AdaMax, RMSProp, ps -> ADAGrad(ps, 0.1), ADADelta, AMSGrad, NADAM]
+  @testset for Opt in [ADAGrad, AdaMax, ADADelta, AMSGrad, NADAM, Descent, ADAM, Nesterov, RMSProp, Momentum]
     w = param(randn(10, 10))
     loss(x) = Flux.mse(w*x, w*x)
-    opt = Opt([w])
-    for t=1:10^5
+    opt = Opt(0.001)
+    if opt isa Descent || opt isa ADAGrad
+      opt = Opt(0.1)
+    end
+    if opt isa ADADelta
+      opt = Opt(0.9)
+    end
+    for t = 1:10^5
       l = loss(rand(10))
       back!(l)
-      opt()
+      delta = Optimise.update!(opt, w.data, w.grad)
+      w.data .-= delta
     end
     @test Flux.mse(w, w) < 0.01
   end
 end
+
+@testset "Optimiser" begin
+  w = randn(10, 10)
+  @testset for Opt in [InvDecay, WeightDecay, ExpDecay]
+    w = param(randn(10, 10))
+    loss(x) = Flux.mse(w*x, w*x)
+    opt = Optimiser(Opt(), ADAM(0.001))
+    for t = 1:10^5
+      l = loss(rand(10))
+      back!(l)
+      delta = Optimise.update!(opt, w.data, w.grad)
+      w.data .-= delta
+    end
+    @test Flux.mse(w, w) < 0.01
+  end
+end
@@ -21,9 +44,10 @@ end
   l = param(1)

   Flux.train!(() -> (sleep(0.1); i += 1; l),
+              (),
               Iterators.repeated((), 100),
-              ()->(),
-              cb = Flux.throttle(() -> (i > 3 && stop()), 1))
+              Descent(),
+              cb = Flux.throttle(() -> (i > 3 && Flux.stop()), 1))
   @test 3 < i < 50
 end