diff --git a/src/optimise/Optimise.jl b/src/optimise/Optimise.jl
index 76b90311..b6f18532 100644
--- a/src/optimise/Optimise.jl
+++ b/src/optimise/Optimise.jl
@@ -2,10 +2,11 @@ module Optimise
 
 export train!,
   Descent, ADAM, Momentum, Nesterov, RMSProp,
-  ADAGrad, AdaMax, ADADelta, AMSGrad, NADAM,
+  ADAGrad, AdaMax, ADADelta, AMSGrad, NADAM, ADAMW,
   InvDecay, ExpDecay, stop, StopException, Compose
 
 include("optimisers.jl")
 include("train.jl")
+include("deprecations.jl")
 
 end
\ No newline at end of file
diff --git a/src/optimise/deprecations.jl b/src/optimise/deprecations.jl
new file mode 100644
index 00000000..6a297619
--- /dev/null
+++ b/src/optimise/deprecations.jl
@@ -0,0 +1,128 @@
+using Base: depwarn
+
+function check_decay(opt, decay)
+  if decay != 0.
+    # `ADAMW` is the only constructor below that returns a `Compose`; it gets
+    # decoupled weight decay, while the others get an `InvDecay` schedule.
+    if opt isa Compose
+      opt = Compose(opt, DescentWeightDecay(1, decay))
+    else
+      opt = Compose(opt, InvDecay(decay))
+    end
+  end
+  opt
+end
+
+# Legacy update rule: a closure that applies `update!` in place over `ps`.
+function updaterule(opt, ps)
+  () -> begin
+    for p in ps
+      delta = update!(opt, p.data, p.grad)
+      p.data .-= delta
+    end
+  end
+end
+
+function Descent(params::AbstractArray, η = 0.1; decay = 0.)
+  depwarn("Descent(ps::Param) is deprecated; use Descent(η::Float64) instead", :Descent)
+
+  ps = params
+  opt = Descent(η)
+  opt = check_decay(opt, decay)
+  updaterule(opt, ps)
+end
+
+function Momentum(params::AbstractArray, η = 0.01; ρ = 0.9, decay = 0.)
+  depwarn("Momentum(ps::Param) is deprecated; use Momentum(η::Float64) instead", :Momentum)
+
+  ps = params
+  opt = Momentum(η, ρ)
+  opt = check_decay(opt, decay)
+  updaterule(opt, ps)
+end
+
+function Nesterov(params::AbstractArray, η = 0.001; ρ = 0.9, decay = 0.)
+  depwarn("Nesterov(ps::Param) is deprecated; use Nesterov(η::Float64) instead", :Nesterov)
+
+  ps = params
+  opt = Nesterov(η, ρ)
+  opt = check_decay(opt, decay)
+  updaterule(opt, ps)
+end
+
+function RMSProp(params::AbstractArray, η = 0.001; ρ = 0.9, decay = 0.)
+  depwarn("RMSProp(ps::Param) is deprecated; use RMSProp(η::Float64) instead", :RMSProp)
+
+  ps = params
+  opt = RMSProp(η, ρ)
+  opt = check_decay(opt, decay)
+  updaterule(opt, ps)
+end
+
+function ADAM(params::AbstractArray, η = 0.001; β1 = 0.9, β2 = 0.999, decay = 0.)
+  depwarn("ADAM(ps::Param) is deprecated; use ADAM(η::Float64) instead", :ADAM)
+
+  ps = params
+  β = (β1, β2)
+  opt = ADAM(η, β)
+  opt = check_decay(opt, decay)
+  updaterule(opt, ps)
+end
+
+function ADAGrad(params::AbstractArray, η::Float64 = 0.1; decay = 0.)
+  depwarn("ADAGrad(ps::Param) is deprecated; use ADAGrad(η::Float64) instead", :ADAGrad)
+
+  ps = params
+  opt = ADAGrad(η)
+  opt = check_decay(opt, decay)
+  updaterule(opt, ps)
+end
+
+function ADADelta(params::AbstractArray, ρ::Float64 = 0.9; decay = 0.)
+  depwarn("ADADelta(ps::Param) is deprecated; use ADADelta(ρ::Float64) instead", :ADADelta)
+
+  ps = params
+  opt = ADADelta(ρ)
+  opt = check_decay(opt, decay)
+  updaterule(opt, ps)
+end
+
+function AdaMax(params::AbstractArray, η = 0.001; β1 = 0.9, β2 = 0.999, decay = 0.)
+  depwarn("AdaMax(ps::Param) is deprecated; use AdaMax(η::Float64) instead", :AdaMax)
+
+  ps = params
+  β = (β1, β2)
+  opt = AdaMax(η, β)
+  opt = check_decay(opt, decay)
+  updaterule(opt, ps)
+end
+
+function AMSGrad(params::AbstractArray, η = 0.001; β1 = 0.9, β2 = 0.999, decay = 0.)
+  depwarn("AMSGrad(ps::Param) is deprecated; use AMSGrad(η::Float64) instead", :AMSGrad)
+
+  ps = params
+  β = (β1, β2)
+  opt = AMSGrad(η, β)
+  opt = check_decay(opt, decay)
+  updaterule(opt, ps)
+end
+
+function NADAM(params::AbstractArray, η = 0.001; β1 = 0.9, β2 = 0.999, decay = 0.)
+  depwarn("NADAM(ps::Param) is deprecated; use NADAM(η::Float64) instead", :NADAM)
+
+  ps = params
+  β = (β1, β2)
+  opt = NADAM(η, β)
+  opt = check_decay(opt, decay)
+  updaterule(opt, ps)
+end
+
+function ADAMW(params::AbstractArray, η = 0.001; β1 = 0.9, β2 = 0.999, decay = 0.)
+  depwarn("ADAMW(ps::Param) is deprecated; use ADAMW(η::Float64) instead", :ADAMW)
+
+  ps = params
+  β = (β1, β2)
+  opt = ADAMW(η, β)
+  opt = check_decay(opt, decay)
+  updaterule(opt, ps)
+end
\ No newline at end of file
diff --git a/src/optimise/optimisers.jl b/src/optimise/optimisers.jl
index 18d8336b..4005db4f 100644
--- a/src/optimise/optimisers.jl
+++ b/src/optimise/optimisers.jl
@@ -1,5 +1,6 @@
 using Flux
 using Base: @get!
+using MacroTools: @forward
 
 const ϵ = 1e-8
 
@@ -15,6 +16,7 @@ mutable struct Descent
   eta::Float64
 end
 
+Descent() = Descent(0.1)
 function update!(o::Descent, x, Δ)
   Δ .*= o.eta
 end
@@ -30,7 +32,7 @@ mutable struct Momentum
   velocity::IdDict
 end
 
-Momentum(η, ρ = 0.9) = Momentum(η, ρ, IdDict())
+Momentum(η = 0.01, ρ = 0.9) = Momentum(η, ρ, IdDict())
 
 function update!(o::Momentum, x, Δ)
   η, ρ = o.eta, o.rho
@@ -50,7 +52,7 @@ mutable struct Nesterov
   velocity::IdDict
 end
 
-Nesterov(η, ρ = 0.9) = Nesterov(η, ρ, IdDict())
+Nesterov(η = 0.001, ρ = 0.9) = Nesterov(η, ρ, IdDict())
 
 function update!(o::Nesterov, x, Δ)
   η, ρ = o.eta, o.rho
@@ -219,10 +221,46 @@ function update!(o::NADAM, x, Δ)
   return Δ
 end
 
+"""
+    ADAMW(η = 0.001, β = (0.9, 0.999), η_decay = 1, γ_decay = 0)
+
+[ADAMW](https://arxiv.org/abs/1711.05101), an Adam variant that fixes weight decay regularisation by decoupling it from the gradient update.
+"""
+ADAMW(η = 0.001, β = (0.9, 0.999), η_decay = 1, γ_decay = 0) = Compose(ADAM(η, β, IdDict()), DescentWeightDecay(η_decay, γ_decay))
+
+# Compose optimisers
+
+"""
+    Compose(opts...)
+
+Compose optimisers so that each one's `update!` is applied to the gradient in sequence; supports built-in or custom gradient updates, and nested `Compose`s are flattened.
+
+Example:
+`Compose(ADAM(), Compose(RMSProp(0.001), ExpDecay(0.02)))`
+"""
 mutable struct Compose
   os::Vector{Any}
 end
+
+Compose(o...) = Compose(flattenCompose(o...))
+
+@forward Compose.os Base.getindex, Base.first, Base.last, Base.lastindex, Base.push!, Base.setindex!
+@forward Compose.os Base.iterate
+
+Base.getindex(c::Compose, i::AbstractArray) = Compose(c.os[i]...)
+
+function flattenCompose(o...)  # splice nested `Compose`s into a flat list
+  res = []
+  for opt in o
+    if opt isa Compose
+      push!(res, flattenCompose(opt.os...)...)
+    else
+      push!(res, opt)
+    end
+  end
+  return res
+end
+
 function update!(o::Compose, x, Δ)
   for opt in o.os
     Δ = update!(opt, x, Δ)
   end
@@ -256,3 +294,15 @@ function update!(o::ExpDecay, x, Δ)
   γ = o.gamma
   @. Δ += γ * x
 end
+
+mutable struct DescentWeightDecay
+  eta::Real
+  gamma::Real
+end
+
+DescentWeightDecay(η = 1, γ = 0) = DescentWeightDecay(η, γ)
+function update!(o::DescentWeightDecay, x, Δ)
+  η, γ = o.eta, o.gamma
+  @. x = x - η * (Δ + γ * x)  # decoupled weight decay, applied to x in place
+  Δ                           # pass the gradient through unchanged
+end