diff --git a/.travis.yml b/.travis.yml
index c03f1de7..e44b3541 100644
--- a/.travis.yml
+++ b/.travis.yml
@@ -4,7 +4,6 @@ os:
   - linux
 # - osx
 julia:
-  - 0.7
   - 1.0
   - nightly
 # uncomment the following lines to override the default test script
diff --git a/REQUIRE b/REQUIRE
index ad3306d6..feec31c3 100644
--- a/REQUIRE
+++ b/REQUIRE
@@ -1,4 +1,4 @@
-julia 0.7
+julia 1.0
 Juno
 MacroTools 0.3.3
 NNlib
diff --git a/src/Flux.jl b/src/Flux.jl
index f4f2db62..48847fbe 100644
--- a/src/Flux.jl
+++ b/src/Flux.jl
@@ -19,8 +19,9 @@ export Tracker, TrackedArray, TrackedVector, TrackedMatrix, param
 include("optimise/Optimise.jl")
 using .Optimise
 using .Optimise: @epochs
-export SGD, ADAM, ADAMW, AdaMax, Momentum, Nesterov,
-  RMSProp, ADAGrad, ADADelta, AMSGrad, NADAM
+export SGD, Descent, ADAM, Momentum, Nesterov, RMSProp,
+  ADAGrad, AdaMax, ADADelta, AMSGrad, NADAM,
+  ADAMW, InvDecay, ExpDecay, WeightDecay
 
 include("utils.jl")
 include("onehot.jl")
diff --git a/src/optimise/Optimise.jl b/src/optimise/Optimise.jl
index c4828c9e..5bb38d1e 100644
--- a/src/optimise/Optimise.jl
+++ b/src/optimise/Optimise.jl
@@ -1,23 +1,12 @@
 module Optimise
 
 export train!,
-  SGD, ADAM, ADAMW, AdaMax, Momentum, Nesterov,
-  RMSProp, ADAGrad, ADADelta, AMSGrad, NADAM, stop, StopException
-
-struct Param{T}
-  x::T
-  Δ::T
-end
-
-Param(x::AbstractArray) = Param(x, zero(x))
+  SGD, Descent, ADAM, Momentum, Nesterov, RMSProp,
+  ADAGrad, AdaMax, ADADelta, AMSGrad, NADAM, ADAMW,
+  InvDecay, ExpDecay, WeightDecay, stop, Optimiser
 
 include("optimisers.jl")
-include("interface.jl")
 include("train.jl")
-
-using Flux.Tracker: TrackedArray
-
-Param(x::TrackedArray) = Param(x.data, x.grad)
-# Base.convert(::Type{Param}, x::TrackedArray) = Param(x.data, x.grad)
+include("deprecations.jl")
 
 end
diff --git a/src/optimise/deprecations.jl b/src/optimise/deprecations.jl
new file mode 100644
index 00000000..40c695b6
--- /dev/null
+++ b/src/optimise/deprecations.jl
@@ -0,0 +1,125 @@
+using Base: depwarn
+
+check_decay(opt, decay) = decay == 0 ? opt : Optimiser(opt, InvDecay(decay))
+
+# Legacy update rule: wrap the new optimiser in a closure so `opt()` still works.
+updaterule(opt, ps) = () -> update!(opt, ps)
+
+function SGD(params::AbstractArray, η = 0.1; decay = 0.)
+  depwarn("SGD(params) is deprecated; use Descent(η::Float64) instead", :SGD)
+
+  ps = params
+  opt = Descent(η)
+  opt = check_decay(opt, decay)
+  updaterule(opt, ps)
+end
+
+function Momentum(params::AbstractArray, η = 0.01; ρ = 0.9, decay = 0.)
+  depwarn("Momentum(params) is deprecated; use Momentum(η::Float64) instead", :Momentum)
+
+  ps = params
+  opt = Momentum(η, ρ)
+  opt = check_decay(opt, decay)
+  updaterule(opt, ps)
+end
+
+function Nesterov(params::AbstractArray, η = 0.001; ρ = 0.9, decay = 0.)
+  depwarn("Nesterov(params) is deprecated; use Nesterov(η::Float64) instead", :Nesterov)
+
+  ps = params
+  opt = Nesterov(η, ρ)
+  opt = check_decay(opt, decay)
+  updaterule(opt, ps)
+end
+
+function RMSProp(params::AbstractArray, η = 0.001; ρ = 0.9, decay = 0.)
+  depwarn("RMSProp(params) is deprecated; use RMSProp(η::Float64) instead", :RMSProp)
+
+  ps = params
+  opt = RMSProp(η, ρ)
+  opt = check_decay(opt, decay)
+  updaterule(opt, ps)
+end
+
+function ADAM(params::AbstractArray, η = 0.001; β1 = 0.9, β2 = 0.999, decay = 0.)
+  depwarn("ADAM(params) is deprecated; use ADAM(η::Float64) instead", :ADAM)
+
+  ps = params
+  β = (β1, β2)
+  opt = ADAM(η, β)
+  opt = check_decay(opt, decay)
+  updaterule(opt, ps)
+end
+
+function ADAGrad(params::AbstractArray, η::Float64 = 0.1; decay = 0.)
+ depwarn("ADAGrad(params) is deprecated; use ADAGrad(η::Float64) instead", :ADAGrad) + + ps = params + opt = ADAGrad(η) + opt = check_decay(opt, decay) + updaterule(opt, ps) +end + +function ADADelta(params::AbstractArray, ρ::Float64 = 0.9; decay = 0.) + depwarn("ADADelta(params) is deprecated; use ADADelta(η::Float64) instead", :ADADelta) + + ps = params + opt = ADADelta(ρ) + opt = check_decay(opt, decay) + updaterule(opt, ps) +end + +function AdaMax(params::AbstractArray, η = 0.001; β1 = 0.9, β2 = 0.999, decay = 0.) + depwarn("AdaMax(params) is deprecated; use AdaMax(η::Float64) instead", :AdaMax) + + ps = params + β = (β1, β2) + opt = AdaMax(η, β) + opt = check_decay(opt, decay) + updaterule(opt, ps) +end + +function AMSGrad(params::AbstractArray, η = 0.001; β1 = 0.9, β2 = 0.999, decay = 0.) + depwarn("AMSGrad(params) is deprecated; use AMSGrad(η::Float64) instead", :AMSGrad) + + ps = params + β = (β1, β2) + opt = AMSGrad(η, β) + opt = check_decay(opt, decay) + updaterule(opt, ps) +end + +function NADAM(params::AbstractArray, η = 0.001; β1 = 0.9, β2 = 0.999, decay = 0.) + depwarn("NADAM(params) is deprecated; use NADAM(η::Float64) instead", :NADAM) + + ps = params + β = (β1, β2) + opt = NADAM(η, β) + opt = check_decay(opt, decay) + updaterule(opt, ps) +end + +function ADAMW(params::AbstractArray, η = 0.001; β1 = 0.9, β2 = 0.999, decay = 0.) + depwarn("ADAMW(params) is deprecated; use ADAMW(η::Float64) instead", :ADAMW) + + ps = params + β = (β1, β2) + opt = ADAMW(η, β) + opt = check_decay(opt, decay) + decay != 0 && (opt = Optimiser(opt, WeightDecay(decay))) + updaterule(opt, ps) +end + +# Old training loop + +struct OldOptimiser + func +end + +update!(opt::OldOptimiser, ps) = opt.func() + +# Train function +function train!(loss, data, opt; cb = () -> ()) + depwarn("train!(loss, data, opt) is deprecated; use train!(loss, params, data, opt) instead", :train!) + train!(loss, (), data, OldOptimiser(opt); cb = cb) +end diff --git a/src/optimise/interface.jl b/src/optimise/interface.jl deleted file mode 100644 index 096e2d87..00000000 --- a/src/optimise/interface.jl +++ /dev/null @@ -1,110 +0,0 @@ -call(f, xs...) = f(xs...) - -# note for optimisers: set to zero -# p.Δ at the end of the weights update -function optimiser(ps, fs...) - ps = [Param(p) for p in ps] - fs = map(ps) do p - os = map(f -> f(p), fs) - () -> foreach(call, os) - end - () -> foreach(call, fs) -end - -""" - SGD(params, η = 0.1; decay = 0) - -Classic gradient descent optimiser with learning rate `η`. -For each parameter `p` and its gradient `δp`, this runs `p -= η*δp`. - -Supports inverse decaying learning rate if the `decay` argument is provided. -""" -SGD(ps, η = 0.1; decay = 0) = - optimiser(ps, p -> invdecay(p, decay), p -> descent(p,η)) - -""" - Momentum(params, η = 0.01; ρ = 0.9, decay = 0) - -SGD with learning rate `η`, momentum `ρ` and optional learning rate inverse decay. -""" -Momentum(ps, η = 0.01; ρ = 0.9, decay = 0) = - optimiser(ps, p->invdecay(p,decay), p->momentum(p, ρ, η), p->descent(p,1)) - -""" - Nesterov(params, η = 0.01; ρ = 0.9, decay = 0) - -SGD with learning rate `η`, Nesterov momentum `ρ` and optional learning rate inverse decay. -""" -Nesterov(ps, η = 0.01; ρ = 0.9, decay = 0) = - optimiser(ps, p->invdecay(p,decay), p->nesterov(p, ρ, η), p->descent(p,1)) - -""" - RMSProp(params, η = 0.001; ρ = 0.9, ϵ = 1e-8, decay = 0) - -[RMSProp](http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf) -optimiser. Parameters other than learning rate don't need tuning. 
-choice for recurrent networks.
-"""
-RMSProp(ps, η = 0.001; ρ = 0.9, ϵ = 1e-8, decay = 0) =
-  optimiser(ps, p->rmsprop(p; η=η, ρ=ρ, ϵ=ϵ), p->invdecay(p,decay), p->descent(p,1))
-
-"""
-    ADAM(params, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0)
-
-[ADAM](https://arxiv.org/abs/1412.6980v8) optimiser.
-"""
-ADAM(ps, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0) =
-  optimiser(ps, p->adam(p; η=η, β1=β1, β2=β2, ϵ=ϵ), p->invdecay(p,decay), p->descent(p,1))
-
-"""
-    ADAMW((params, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0)
-
-[ADAMW](https://arxiv.org/abs/1711.05101) fixing weight decay regularization in Adam.
-"""
-ADAMW(ps, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0) =
-  optimiser(ps, p->adam(p; η=η, β1=β1, β2=β2, ϵ=ϵ), p->descentweightdecay(p,1,decay))
-
-"""
-    AdaMax(params, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0)
-
-[AdaMax](https://arxiv.org/abs/1412.6980v9) optimiser. Variant of ADAM based on
-the ∞-norm.
-"""
-AdaMax(ps, η = 0.002; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0) =
-  optimiser(ps, p->adamax(p; η=η, β1=β1, β2=β2, ϵ=ϵ), p->invdecay(p,decay), p->descent(p,1))
-
-"""
-    ADAGrad(params, η = 0.01; ϵ = 1e-8, decay = 0)
-
-[ADAGrad](http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf) optimiser.
-Parameters don't need tuning.
-"""
-ADAGrad(ps, η = 0.01; ϵ = 1e-8, decay = 0) =
-  optimiser(ps, p->adagrad(p; η=η, ϵ=ϵ), p->invdecay(p,decay), p->descent(p,1))
-
-"""
-    ADADelta(params; ρ = 0.9, ϵ = 1e-8, decay = 0)
-
-[ADADelta](http://arxiv.org/abs/1212.5701) optimiser. Parameters don't need
-tuning.
-"""
-ADADelta(ps; ρ = 0.9, ϵ = 1e-8, decay = 0) =
-  optimiser(ps, p->adadelta(p; ρ=ρ, ϵ=ϵ), p->descent(p,1))
-
-"""
-    AMSGrad(params; η = 0.001, β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0)
-
-[AMSGrad](https://openreview.net/forum?id=ryQu7f-RZ) optimiser. Parameters don't need
-tuning.
-"""
-AMSGrad(ps, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0) =
-  optimiser(ps, p -> amsgrad(p; η = η, β1 = β1, β2 = β2, ϵ = ϵ), p -> invdecay(p, decay), p -> descent(p, 1))
-
-"""
-    NADAM(params, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0)
-
-[NADAM](https://openreview.net/pdf?id=OM0jvwB8jIp57ZJjtNEZ) optimiser. Parameters other
-than learning rate don't need tuning.
-"""
-NADAM(ps, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0) =
-  optimiser(ps, p->nadam(p; η=η, β1=β1, β2=β2, ϵ=ϵ), p->invdecay(p,decay), p->descent(p,1))
diff --git a/src/optimise/optimisers.jl b/src/optimise/optimisers.jl
index 1f7a7c9c..2accc4bc 100644
--- a/src/optimise/optimisers.jl
+++ b/src/optimise/optimisers.jl
@@ -1,130 +1,304 @@
-function descent(p::Param, η::Real)
-  function ()
-    @. p.x -= η * p.Δ
-    @. p.Δ = 0
+using Flux
+using Base: @get!
+using MacroTools: @forward
+
+const ϵ = 1e-8
+
+# TODO: should use weak refs
+
+"""
+    Descent(η)
+
+Classic gradient descent optimiser with learning rate `η`.
+For each parameter `p` and its gradient `δp`, this runs `p -= η*δp`.
+"""
+mutable struct Descent
+  eta::Float64
+end
+
+Descent() = Descent(0.1)
+
+function update!(o::Descent, x, Δ)
+  Δ .*= o.eta
+end
+
+"""
+    Momentum(η = 0.01, ρ = 0.9)
+
+Gradient descent with learning rate `η` and momentum `ρ`.
+"""
+mutable struct Momentum
+  eta::Float64
+  rho::Float64
+  velocity::IdDict
+end
+
+Momentum(η = 0.01, ρ = 0.9) = Momentum(η, ρ, IdDict())
+
+function update!(o::Momentum, x, Δ)
+  η, ρ = o.eta, o.rho
+  v = get!(o.velocity, x, zero(x))::typeof(x)
+  @. v = ρ * v - η * Δ
+  @. Δ = -v
+end
+
+"""
+    Nesterov(η = 0.001, ρ = 0.9)
+
+Gradient descent with learning rate `η` and Nesterov momentum `ρ`.
+"""
+mutable struct Nesterov
+  eta::Float64
+  rho::Float64
+  velocity::IdDict
+end
+
+Nesterov(η = 0.001, ρ = 0.9) = Nesterov(η, ρ, IdDict())
+
+function update!(o::Nesterov, x, Δ)
+  η, ρ = o.eta, o.rho
+  v = get!(o.velocity, x, zero(x))::typeof(x)
+  d = @. ρ^2 * v - (1+ρ) * η * Δ
+  @. v = ρ*v - η*Δ
+  @. Δ = -d
+end
+
+"""
+    RMSProp(η = 0.001, ρ = 0.9)
+
+[RMSProp](http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf)
+optimiser. Parameters other than learning rate don't need tuning. Often a good
+choice for recurrent networks.
+"""
+mutable struct RMSProp
+  eta::Float64
+  rho::Float64
+  acc::IdDict
+end
+
+RMSProp(η = 0.001, ρ = 0.9) = RMSProp(η, ρ, IdDict())
+
+function update!(o::RMSProp, x, Δ)
+  η, ρ = o.eta, o.rho
+  acc = get!(o.acc, x, zero(x))::typeof(x)
+  @. acc = ρ * acc + (1 - ρ) * Δ^2
+  @. Δ *= η / (√acc + ϵ)
+end
+
+"""
+    ADAM(η = 0.001, β = (0.9, 0.999))
+
+[ADAM](https://arxiv.org/abs/1412.6980v8) optimiser.
+"""
+mutable struct ADAM
+  eta::Float64
+  beta::Tuple{Float64,Float64}
+  state::IdDict
+end
+
+ADAM(η = 0.001, β = (0.9, 0.999)) = ADAM(η, β, IdDict())
+
+function update!(o::ADAM, x, Δ)
+  η, β = o.eta, o.beta
+  mt, vt, βp = get!(o.state, x, (zero(x), zero(x), β))
+  @. mt = β[1] * mt + (1 - β[1]) * Δ
+  @. vt = β[2] * vt + (1 - β[2]) * Δ^2
+  @. Δ = mt / (1 - βp[1]) / (√(vt / (1 - βp[2])) + ϵ) * η
+  o.state[x] = (mt, vt, βp .* β)
+  return Δ
+end
+
+"""
+    AdaMax(η = 0.001, β = (0.9, 0.999))
+
+[AdaMax](https://arxiv.org/abs/1412.6980v9) optimiser. Variant of ADAM based on
+the ∞-norm.
+"""
+mutable struct AdaMax
+  eta::Float64
+  beta::Tuple{Float64,Float64}
+  state::IdDict
+end
+
+AdaMax(η = 0.001, β = (0.9, 0.999)) = AdaMax(η, β, IdDict())
+
+function update!(o::AdaMax, x, Δ)
+  η, β = o.eta, o.beta
+  mt, ut, βp = get!(o.state, x, (zero(x), zero(x), β))
+  @. mt = β[1] * mt + (1 - β[1]) * Δ
+  @. ut = max(β[2] * ut, abs(Δ))
+  @. Δ = (η/(1 - βp[1])) * mt/(ut + ϵ)
+  o.state[x] = (mt, ut, βp .* β)
+  return Δ
+end
+
+"""
+    ADAGrad(η = 0.1)
+
+[ADAGrad](http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf) optimiser.
+Parameters don't need tuning.
+"""
+mutable struct ADAGrad
+  eta::Float64
+  acc::IdDict
+end
+
+ADAGrad(η = 0.1) = ADAGrad(η, IdDict())
+
+function update!(o::ADAGrad, x, Δ)
+  η = o.eta
+  acc = get!(o.acc, x, fill(ϵ, size(x)))::typeof(x)
+  @. acc += Δ^2
+  @. Δ *= η / √(acc + ϵ)
+end
+
+"""
+    ADADelta(ρ = 0.9)
+
+[ADADelta](http://arxiv.org/abs/1212.5701) optimiser. Parameters don't need
+tuning.
+"""
+mutable struct ADADelta
+  rho::Float64
+  state::IdDict
+end
+
+ADADelta(ρ = 0.9) = ADADelta(ρ, IdDict())
+
+function update!(o::ADADelta, x, Δ)
+  ρ = o.rho
+  acc, Δacc = get!(o.state, x, (zero(x), zero(x)))
+  @. acc = ρ * acc + (1 - ρ) * Δ^2
+  @. Δ *= √(Δacc + ϵ) / √(acc + ϵ)
+  @. Δacc = ρ * Δacc + (1 - ρ) * Δ^2
+  return Δ
+end
+
+"""
+    AMSGrad(η = 0.001, β = (0.9, 0.999))
+
+[AMSGrad](https://openreview.net/forum?id=ryQu7f-RZ) optimiser. Parameters don't need
+tuning.
+"""
+mutable struct AMSGrad
+  eta::Float64
+  beta::Tuple{Float64, Float64}
+  state::IdDict
+end
+
+AMSGrad(η = 0.001, β = (0.9, 0.999)) = AMSGrad(η, β, IdDict())
+
+function update!(o::AMSGrad, x, Δ)
+  η, β = o.eta, o.beta
+  mt, vt, v̂t = get!(o.state, x, (fill(ϵ, size(x)), fill(ϵ, size(x)), fill(ϵ, size(x))))
+  @. mt = β[1] * mt + (1 - β[1]) * Δ
+  @. vt = β[2] * vt + (1 - β[2]) * Δ ^ 2
+  @. v̂t = max(v̂t, vt)
+  @. Δ = η * mt / √v̂t
+end
+
+"""
+    NADAM(η = 0.001, β = (0.9, 0.999))
+
+[NADAM](http://cs229.stanford.edu/proj2015/054_report.pdf) optimiser. Parameters don't need
+tuning.
+"""
+mutable struct NADAM
+  eta::Float64
+  beta::Tuple{Float64, Float64}
+  state::IdDict
+end
+
+NADAM(η = 0.001, β = (0.9, 0.999)) = NADAM(η, β, IdDict())
+
+function update!(o::NADAM, x, Δ)
+  η, β = o.eta, o.beta
+  mt, vt, βp = get!(o.state, x, (zero(x), zero(x), o.beta))
+  β1p, β2p = βp
+  @. mt = β[1] * mt + (1 - β[1]) * Δ
+  @. vt = β[2] * vt + (1 - β[2]) * Δ^2
+  @. Δ = (β[1] * mt / (1 - β[1] * β1p) + (1 - β[1]) * Δ / (1 - β1p)) / √(vt * β[2] / (1 - β2p) + ϵ) * η
+  o.state[x] = (mt, vt, (β1p * β[1], β2p * β[2]))
+  return Δ
+end
+
+"""
+    ADAMW(η = 0.001, β = (0.9, 0.999), decay = 0)
+
+[ADAMW](https://arxiv.org/abs/1711.05101), a variant of ADAM with decoupled weight decay regularisation.
+"""
+ADAMW(η = 0.001, β = (0.9, 0.999), decay = 0) =
+  Optimiser(ADAM(η, β), WeightDecay(decay))
+
+# Compose optimisers
+
+"""
+    Optimiser(a, b, c...)
+
+Combine several optimisers into one; each optimiser produces a modified gradient
+that will be fed into the next, and this is finally applied to the parameter as
+usual.
+"""
+mutable struct Optimiser
+  os::Vector{Any}
+end
+
+Optimiser(o...) = Optimiser(Any[o...])
+
+@forward Optimiser.os Base.getindex, Base.first, Base.last, Base.lastindex, Base.push!, Base.setindex!
+@forward Optimiser.os Base.iterate
+
+Base.getindex(c::Optimiser, i::AbstractArray) = Optimiser(c.os[i]...)
+
+function update!(o::Optimiser, x, Δ)
+  for opt in o.os
+    Δ = update!(opt, x, Δ)
   end
+  return Δ
 end
 
-# Ref: https://arxiv.org/abs/1711.05101.pdf
-function descentweightdecay(p::Param, η::Real, γ::Real)
-  function ()
-    @. p.x = p.x - η * (p.Δ + γ * p.x)
-    @. p.Δ = 0
+mutable struct InvDecay
+  gamma::Float64
+  state::IdDict
+end
+
+InvDecay(γ = 0.001) = InvDecay(γ, IdDict())
+
+function update!(o::InvDecay, x, Δ)
+  γ = o.gamma
+  n = get!(o.state, x, 1)
+  Δ .*= 1 / (1 + γ * n)
+  o.state[x] = n + 1
+  return Δ
+end
+
+mutable struct ExpDecay
+  eta::Float64
+  decay::Float64
+  step::Int64
+  clip::Float64
+  current::IdDict
+end
+
+ExpDecay(opt = 0.001, decay = 0.1, decay_step = 1000, clip = 1e-4) = ExpDecay(opt, decay, decay_step, clip, IdDict())
+
+function update!(o::ExpDecay, x, Δ)
+  η, s, decay = o.eta, o.step, o.decay
+  n = o.current[x] = get(o.current, x, 0) + 1
+  if o.current[x]%s == 0 && count(x -> x%s == 0, values(o.current)) == 1
+    η = max(η * decay^(s / n), o.clip)
+    o.eta = η
   end
+  @. Δ *= η
 end
 
-function momentum(p::Param, ρ, η)
-  v = zero(p.x)
-  function ()
-    @. v = ρ * v - η * p.Δ
-    @. p.Δ = -v
-  end
+mutable struct WeightDecay
+  wd::Real
 end
 
-# Ref. https://arxiv.org/pdf/1212.0901.pdf
-function nesterov(p::Param, ρ, η)
-  v = zero(p.x)
-  function ()
-    d = @. ρ^2 * v - (1+ρ) * η * p.Δ
-    @. v = ρ*v - η*p.Δ
-    @. p.Δ = -d
-  end
-end
+WeightDecay() = WeightDecay(0)
 
-function rmsprop(p::Param; η::Real = 0.001, ρ::Real = 0.9, ϵ::Real = 1e-8)
-  acc = zero(p.x)
-  function ()
-    @. acc = ρ * acc + (1 - ρ) * p.Δ^2
-    @. p.Δ *= η / √(acc + ϵ)
-  end
-end
-
-function adagrad(p::Param; η::Real = 0.01, ϵ::Real = 1e-8)
-  acc = zero(p.x) .+ ϵ
-  function ()
-    @. acc += p.Δ^2
-    @. p.Δ *= η / √(acc + ϵ)
-  end
-end
-
-function adadelta(p::Param; ρ::Real = 0.9, ϵ::Real = 1e-8)
-  acc = zero(p.x)
-  Δacc = zero(p.x)
-  function ()
-    @. acc = ρ * acc + (1 - ρ) * p.Δ^2
-    @. p.Δ *= √(Δacc + ϵ) / √(acc + ϵ)
-    @. Δacc = ρ * Δacc + (1 - ρ) * p.Δ^2
-  end
-end
-
-function adam(p::Param; η::Real = 0.001, β1::Real = 0.9, β2::Real = 0.999, ϵ::Real = 1e-8)
-  mt = zero(p.x)
-  vt = zero(p.x)
-  β1p, β2p = β1, β2
-  function ()
-    @. mt = β1 * mt + (1 - β1) * p.Δ
-    @. vt = β2 * vt + (1 - β2) * p.Δ^2
-    @. p.Δ = mt / (1 - β1p) / √(vt / (1 - β2p) + ϵ) * η
-    β1p *= β1
-    β2p *= β2
-  end
-end
-
-function adamax(p::Param; η::Real = 0.002, β1::Real = 0.9, β2::Real = 0.999, ϵ::Real = 1e-8)
-  mt = zero(p.x)
-  ut = zero(p.x)
-  β1p = β1
-  function ()
-    @. mt = β1 * mt + (1 - β1) * p.Δ
-    @. ut = max(β2 * ut, abs(p.Δ))
-    @. p.Δ = (η/(1 - β1p)) * mt/(ut + ϵ)
-    β1p *= β1
-  end
-end
-
-function amsgrad(p::Param; η::Real = 0.001, β1::Real = 0.9, β2::Real = 0.999, ϵ::Real = 1e-8)
-  mt = zero(p.x)
-  vt = zero(p.x) .+ ϵ
-  v̂t = zero(p.x) .+ ϵ
-  function ()
-    @. mt = β1 * mt + (1 - β1) * p.Δ
-    @. vt = β2 * vt + (1 - β2) * p.Δ ^ 2
-    @. v̂t = max.(v̂t, vt)
-    @. p.Δ = η * mt / √v̂t
-  end
-end
-
-function nadam(p::Param; η::Real = 0.001, β1::Real = 0.9, β2::Real = 0.999, ϵ::Real = 1e-8)
-  mt = zero(p.x)
-  vt = zero(p.x)
-  β1p, β2p = β1, β2
-  function ()
-    @. mt = β1 * mt + (1 - β1) * p.Δ
-    @. vt = β2 * vt + (1 - β2) * p.Δ^2
-    @. p.Δ = (β1 * mt / (1 - β1 * β1p) + (1 - β1) * p.Δ / (1 - β1p)) / √(vt * β2 / (1 - β2p) + ϵ) * η
-    β1p *= β1
-    β2p *= β2
-  end
-end
-
-clip(p::Param, thresh::Real) = () -> clamp!(p.Δ, -thresh, thresh)
-
-function expdecay(p::Param, γ::Real)
-  if γ != 0
-    return () -> p.Δ .+= γ .* p.x
-  else
-    return () -> nothing
-  end
-end
-
-function invdecay(p::Param, γ::Real)
-  if γ != 0
-    n = 0
-    return () -> begin
-      p.Δ .*= 1 / (1 + γ * n)
-      n += 1
-    end
-  else
-    return () -> nothing
-  end
+function update!(o::WeightDecay, x, Δ)
+  wd = o.wd
+  @. Δ += wd * x
 end
diff --git a/src/optimise/train.jl b/src/optimise/train.jl
index 09893873..28bdf27b 100644
--- a/src/optimise/train.jl
+++ b/src/optimise/train.jl
@@ -1,7 +1,16 @@
 using Juno
-using Flux.Tracker: back!
+using Flux.Tracker: data, grad, back!
 import Base.depwarn
 
+function update!(opt, xs)
+  for x in xs
+    Δ = update!(opt, x.data, x.grad)
+    x.data .-= Δ
+    Δ .= 0
+  end
+end
+
+# Callback niceties
 runall(f) = f
 runall(fs::AbstractVector) = () -> foreach(call, fs)
 
@@ -35,7 +44,7 @@ function stop()
 end
 
 """
-    train!(loss, data, opt)
+    train!(loss, params, data, opt)
 
 For each datapoint `d` in `data` computes the gradient of `loss(d...)` through
 backpropagation and calls the optimizer `opt`.
@@ -44,7 +53,7 @@ Takes a callback as keyword argument `cb`. For example, this will print "trainin
 every 10 seconds:
 
 ```julia
-Flux.train!(loss, data, opt,
+Flux.train!(loss, params, data, opt,
             cb = throttle(() -> println("training"), 10))
 ```
 
@@ -52,14 +61,14 @@ The callback can return `:stop` to interrupt the training loop.
 
 Multiple optimisers and callbacks can be passed to `opt` and `cb` as arrays.
 """
-function train!(loss, data, opt; cb = () -> ())
+function train!(loss, ps, data, opt; cb = () -> ())
   cb = runall(cb)
   opt = runall(opt)
   @progress for d in data
     try
      l = loss(d...)
     @interrupts back!(l)
-      opt()
+      update!(opt, ps)
      if cb() == :stop
        depwarn("Use of `:stop` is deprecated; use `Flux.stop()` instead", :stop)
        break
diff --git a/src/tracker/back.jl b/src/tracker/back.jl
index 17bf42df..af130dd3 100644
--- a/src/tracker/back.jl
+++ b/src/tracker/back.jl
@@ -69,15 +69,28 @@ end
 # Out-of-place gradients
 
 struct Params
-  params::IdSet
-  Params(xs) = new(IdSet(xs))
+  order::Vector{Any}
+  params::IdSet{Any}
+  Params() = new([], IdSet())
 end
 
-@forward Params.params Base.iterate, Base.length
+@forward Params.order Base.iterate, Base.length
+
+function Base.push!(ps::Params, x)
+  if !(x in ps.params)
+    push!(ps.order, x)
+    push!(ps.params, x)
+  end
+  return ps
+end
+
+Base.push!(ps::Params, x...) = (foreach(x -> push!(ps, x), x); ps)
+
+Params(xs) = push!(Params(), xs...)
 
 function Base.show(io::IO, ps::Params)
   print(io, "Params([")
-  join(io, ps.params, ", ")
+  join(io, ps.order, ", ")
   print(io, "])")
 end
diff --git a/src/tracker/idset.jl b/src/tracker/idset.jl
index 62570c99..372e262a 100644
--- a/src/tracker/idset.jl
+++ b/src/tracker/idset.jl
@@ -7,6 +7,7 @@ Base.eltype(::IdSet{T}) where T = T
 
 IdSet() = IdSet{Any}()
 
+Base.push!(s::IdSet) = s
 Base.push!(s::IdSet{T}, x::T) where T = (s.dict[x] = nothing; s)
 Base.delete!(s::IdSet{T}, x::T) where T = (delete!(s.dict, x); s)
 Base.in(x, s::IdSet) = haskey(s.dict, x)
diff --git a/src/treelike.jl b/src/treelike.jl
index 9b3518d3..88e878c4 100644
--- a/src/treelike.jl
+++ b/src/treelike.jl
@@ -40,7 +40,7 @@ function prefor(f, x; seen = IdSet())
 end
 
 function params(m)
-  ps = []
+  ps = Params()
   prefor(p ->
     Tracker.istracked(p) && Tracker.isleaf(p) &&
       !any(p′ -> p′ === p, ps) && push!(ps, p),
diff --git a/test/optimise.jl b/test/optimise.jl
index 502d9ab2..98d06edb 100644
--- a/test/optimise.jl
+++ b/test/optimise.jl
@@ -3,14 +3,37 @@ using Flux.Tracker
 using Test
 
 @testset "Optimise" begin
   w = randn(10, 10)
-  @testset for Opt in [SGD, Nesterov, Momentum, ADAM, AdaMax, RMSProp, ps -> ADAGrad(ps, 0.1), ADADelta, AMSGrad, NADAM]
+  @testset for Opt in [ADAGrad, AdaMax, ADADelta, AMSGrad, NADAM, Descent, ADAM, Nesterov, RMSProp, Momentum]
     w′ = param(randn(10, 10))
     loss(x) = Flux.mse(w*x, w′*x)
-    opt = Opt([w′])
-    for t=1:10^5
+    opt = Opt(0.001)
+    if opt isa Descent || opt isa ADAGrad
+      opt = Opt(0.1)
+    end
+    if opt isa ADADelta
+      opt = Opt(0.9)
+    end
+    for t = 1:10^5
       l = loss(rand(10))
       back!(l)
-      opt()
+      delta = Optimise.update!(opt, w′.data, w′.grad)
+      w′.data .-= delta
+    end
+    @test Flux.mse(w, w′) < 0.01
+  end
+end
+
+@testset "Optimiser" begin
+  w = randn(10, 10)
+  @testset for Opt in [InvDecay, WeightDecay, ExpDecay]
+    w′ = param(randn(10, 10))
+    loss(x) = Flux.mse(w*x, w′*x)
+    opt = Optimiser(Opt(), ADAM(0.001))
+    for t = 1:10^5
+      l = loss(rand(10))
+      back!(l)
+      delta = Optimise.update!(opt, w′.data, w′.grad)
+      w′.data .-= delta
     end
     @test Flux.mse(w, w′) < 0.01
   end
@@ -21,9 +44,10 @@ end
   l = param(1)
 
   Flux.train!(() -> (sleep(0.1); i += 1; l),
+              (),
               Iterators.repeated((), 100),
-              ()->(),
-              cb = Flux.throttle(() -> (i > 3 && stop()), 1))
+              Descent(),
+              cb = Flux.throttle(() -> (i > 3 && Flux.stop()), 1))
 
   @test 3 < i < 50
 end
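
Note for reviewers: a minimal sketch of how the new API introduced above fits together end to end. The model, data, and hyperparameters here are made up purely for illustration; the sketch assumes the tracker-based Flux on this branch, where `params` now returns the ordered `Params` collection added in `src/tracker/back.jl`.

```julia
using Flux

# Hypothetical model and data, just to exercise the new training path.
model = Chain(Dense(10, 32, relu), Dense(32, 1))
loss(x, y) = Flux.mse(model(x), y)
data = [(rand(10, 16), rand(1, 16)) for _ = 1:100]

ps  = Flux.params(model)   # an ordered Params collection, not a plain Vector any more
opt = ADAM(0.001)          # optimisers no longer capture the parameters at construction

# train! now takes the parameters explicitly; each step calls update!(opt, ps),
# which applies update!(opt, x.data, x.grad) to every tracked parameter.
Flux.train!(loss, ps, data, opt)
```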
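
And a minimal sketch of gradient-transform composition via `Optimiser`, mirroring the new "Optimiser" test set; the single parameter `W`, the toy loss, and the chosen decay values are illustrative only, not a recommended configuration.

```julia
using Flux
using Flux.Tracker: back!
using Flux.Optimise: Optimiser, ExpDecay, WeightDecay, update!

W = param(randn(5, 5))        # a single tracked parameter
loss(x) = sum(W * x)

# update!(::Optimiser, x, Δ) walks the wrapped optimisers left to right, so the
# decay schedule and the L2-style weight decay adjust the gradient before ADAM.
opt = Optimiser(ExpDecay(0.001, 0.1, 1000, 1e-4), WeightDecay(1e-4), ADAM(0.001))

back!(loss(rand(5)))
Δ = update!(opt, W.data, W.grad)   # returns the step to subtract from the data
W.data .-= Δ
W.grad .= 0                        # reset the gradient, as train!'s update! does
```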