fixed weight decay definition

Author: Dhairya Gandhi, 2018-10-11 10:07:16 +05:30
parent 0f2019eba5
commit fe8c147f72
6 changed files with 35 additions and 26 deletions
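In short: weight decay is redefined as its own gradient transformation (`WeightDecay`) that gets chained with a plain descent step through the renamed `Optimiser` container (previously `Compose`), and `train!` now applies the resulting update itself. A minimal, self-contained sketch of that composition pattern follows; the names `Decay`, `Scale`, and `step!` are illustrative stand-ins, not Flux code.

```julia
# Illustrative sketch of decoupled weight decay as a chained gradient transform.
# `Decay` and `Scale` stand in for Flux's WeightDecay and Descent; this is not Flux code.
struct Decay; wd::Float64 end    # contributes wd * x to the gradient
struct Scale; eta::Float64 end   # scales the gradient by the learning rate

apply(o::Decay, x, grad) = grad .+ o.wd .* x
apply(o::Scale, x, grad) = o.eta .* grad

# Mirror Optimiser + update!: each stage rewrites the gradient, then the caller steps.
function step!(chain, x, grad)
  for o in chain
    grad = apply(o, x, grad)
  end
  x .-= grad
end

x, grad = randn(3), randn(3)
step!((Decay(1e-4), Scale(0.1)), x, grad)
```

This mirrors the `update!(o::Optimiser, x, Δ)` loop further down, where each chained optimiser rewrites `Δ` before `train!` subtracts it from the parameter.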

View File

@@ -21,7 +21,7 @@ using .Optimise
 using .Optimise: @epochs
 export Descent, ADAM, Momentum, Nesterov, RMSProp,
   ADAGrad, AdaMax, ADADelta, AMSGrad, NADAM,
-  InvDecay, ExpDecay
+  ADAMW, InvDecay, ExpDecay, WeightDecay, DescentWeightDecay
 include("utils.jl")
 include("onehot.jl")

View File

@@ -3,7 +3,7 @@ module Optimise
 export train!,
   Descent, ADAM, Momentum, Nesterov, RMSProp,
   ADAGrad, AdaMax, ADADelta, AMSGrad, NADAM, ADAMW,
-  InvDecay, ExpDecay, stop, Compose
+  InvDecay, ExpDecay, WeightDecay, stop, Optimiser
 include("optimisers.jl")
 include("train.jl")

View File

@@ -5,9 +5,9 @@ function check_decay(opt, decay)
     opt = opt
   else
     if opt isa ADAMW
-      opt = Compose(opt, DescentWeightDecay(1, decay))
+      opt = Optimiser(opt, DescentWeightDecay(1, decay))
     else
-      opt = Compose(opt, InvDecay(decay))
+      opt = Optimiser(opt, InvDecay(decay))
     end
   end
   opt
@@ -126,3 +126,9 @@ function ADAMW(params::AbstractArray, η = 0.001; β1 = 0.9, β2 = 0.999, decay
   opt = check_decay(opt, decay)
   updaterule(opt, ps)
 end
+
+# Train function
+function train!(loss::Function, data, opt; cb = () -> ())
+  depwarn("train!(loss, data, opt; cb) is deprecated; use train!(model, data, loss, opt; cb) instead", :train)
+  train!(opt.ps, loss, data, opt.opt; cb = cb)
+end

View File

@@ -85,7 +85,7 @@ function update!(o::RMSProp, x, Δ)
 end

 """
-    ADAM(η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08)
+    ADAM(η = 0.001, β = (0.9, 0.999))

 [ADAM](https://arxiv.org/abs/1412.6980v8) optimiser.
 """
@@ -226,28 +226,29 @@
 [ADAMW](https://arxiv.org/abs/1711.05101) fixing weight decay regularization in Adam.
 """
-ADAMW(η = 0.001, β = (0.9, 0.999), η_decay = 1, γ_decay = 0) = Compose(ADAM(η, β, IdDict()), DescentWeightDecay(η_decay, γ_decay))
+ADAMW(η = 0.001, β = (0.9, 0.999), η_decay = 1, γ_decay = 0) = Optimiser(ADAM(η, β, IdDict()), DescentWeightDecay(η_decay, γ_decay))

 # Compose optimizers

 """
-    Compose(a, b, c...)
+    Optimiser(a, b, c...)

 Combine several optimisers into one; each optimiser produces a modified gradient
 that will be fed into the next, and this is finally applied to the parameter as
 usual.
 """
-mutable struct Compose
+mutable struct Optimiser
   os::Vector{Any}
-  Compose(o...) = Compose(Any[o...])
 end

-@forward Compose.os Base.getindex, Base.first, Base.last, Base.lastindex, Base.push!, Base.setindex!
-@forward Compose.os Base.iterate
+Optimiser(o...) = Optimiser(Any[o...])
+
+@forward Optimiser.os Base.getindex, Base.first, Base.last, Base.lastindex, Base.push!, Base.setindex!
+@forward Optimiser.os Base.iterate

-Base.getindex(c::Compose, i::AbstractArray) = Compose(c.os[i]...)
+Base.getindex(c::Optimiser, i::AbstractArray) = Optimiser(c.os[i]...)

-function update!(o::Compose, x, Δ)
+function update!(o::Optimiser, x, Δ)
   for opt in o.os
     Δ = update!(opt, x, Δ)
   end
@@ -281,14 +282,15 @@ function update!(o::ExpDecay, x, Δ)
   @. Δ += γ * x
 end

-mutable struct DescentWeightDecay
+mutable struct WeightDecay
   eta::Real
-  gamma::Real
+  wd::Real
 end

-DescentWeightDecay(η = 1) = DescentWeightDecay(η, 0)
-function update!(o::DescentWeightDecay, x, Δ)
-  η, γ = o.eta, o.gamma
-  @. x = x - η * (Δ + γ * x)
+WeightDecay(η = 1) = WeightDecay(η, 0)
+function update!(o::WeightDecay, x, Δ)
+  η, wd = o.eta, o.wd
+  @. Δ += wd * x
+  Δ
 end
+
+DescentWeightDecay(η = 0.1, γ = 0) = Optimiser(WeightDecay(), Descent(η))
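Putting the pieces above together: `DescentWeightDecay(η, γ)` is now just `Optimiser(WeightDecay(), Descent(η))`, and `update!` threads the gradient through the chain. A hedged usage sketch against the API as it appears in this diff; the numeric values are illustrative, and it assumes `Descent(η)`'s `update!` (not shown in this hunk) scales the gradient by `η`.

```julia
using Flux, Flux.Optimise

x = randn(5)    # a parameter array (illustrative)
Δ = randn(5)    # its gradient (illustrative)

# WeightDecay followed by Descent — the same chain as the new DescentWeightDecay.
opt = Optimiser(WeightDecay(1, 0.01), Descent(0.1))

# update! folds the gradient through the chain:
#   Δ ← Δ + 0.01 * x   (WeightDecay, via its wd field)
#   Δ ← 0.1 * Δ        (Descent, assumed to scale by its learning rate)
Δ′ = Flux.Optimise.update!(opt, x, Δ)
x .-= Δ′    # the caller applies the step, as the reworked train! does
```

Note that, as defined here, `WeightDecay()` defaults to `WeightDecay(1, 0)`, i.e. a zero `wd` and therefore no decay, so a non-zero decay has to be passed explicitly.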

View File

@@ -45,7 +45,7 @@ function stop()
 end

 """
-    train!(loss, data, opt)
+    train!(model, loss, data, opt)

 For each datapoint `d` in `data` computes the gradient of `loss(d...)` through
 backpropagation and calls the optimizer `opt`.
@@ -54,7 +54,7 @@ Takes a callback as keyword argument `cb`. For example, this will print "training"
 every 10 seconds:

 ```julia
-Flux.train!(loss, data, opt,
+Flux.train!(model, loss, data, opt,
             cb = throttle(() -> println("training"), 10))
 ```
@@ -62,14 +62,14 @@ The callback can return `:stop` to interrupt the training loop.
 Multiple optimisers and callbacks can be passed to `opt` and `cb` as arrays.
 """
-function train!(loss, data, opt; cb = () -> ())
+function train!(ps::Array, loss, data, opt; cb = () -> ())
   cb = runall(cb)
   opt = runall(opt)
   @progress for d in data
     try
       l = loss(d...)
       @interrupts back!(l)
-      opt()
+      foreach(x -> x.data .-= update!(opt, x.data, x.grad), ps)
       if cb() == :stop
         depwarn("Use of `:stop` is deprecated; use `Flux.stop()` instead", :stop)
         break
@@ -83,6 +83,7 @@ function train!(loss, data, opt; cb = () -> ())
     end
   end
 end
+train!(model, loss, data, opt; cb = () -> ()) = train!(params(model), loss, data, opt; cb = cb)

 """
     @epochs N body
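For reference, the reworked loop above applies `x.data .-= update!(opt, x.data, x.grad)` to each parameter instead of calling an opaque `opt()` closure, and a convenience method forwards a model to its `params`. A hedged sketch of calling the new entry points; `Dense`, `mse`, and the random data are illustrative placeholders, not part of this commit.

```julia
using Flux
using Flux: throttle, params

model = Dense(10, 1)                          # illustrative model
data  = [(rand(10), rand(1)) for _ = 1:100]   # illustrative dataset
loss(x, y) = Flux.mse(model(x), y)
opt = ADAM(0.001)

# Model-level convenience method added in this commit:
Flux.train!(model, loss, data, opt, cb = throttle(() -> println("training"), 10))

# Equivalent explicit form over the parameter array:
Flux.train!(params(model), loss, data, opt)
```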

View File

@@ -23,12 +23,12 @@ using Test
   end
 end

-@testset "Compose" begin
+@testset "Optimiser" begin
   w = randn(10, 10)
   @testset for Opt in [InvDecay, ExpDecay]
     w = param(randn(10, 10))
     loss(x) = Flux.mse(w*x, w*x)
-    opt = Compose(vec([Opt(), ADAM(0.001)]))
+    opt = Optimiser(Opt(), ADAM(0.001))
     for t = 1:10^5
       l = loss(rand(10))
       back!(l)