fixed weight decay definition

Dhairya Gandhi 2018-10-11 10:07:16 +05:30
parent 0f2019eba5
commit fe8c147f72
6 changed files with 35 additions and 26 deletions

--- a/src/Flux.jl
+++ b/src/Flux.jl

@@ -21,7 +21,7 @@ using .Optimise
 using .Optimise: @epochs
 export Descent, ADAM, Momentum, Nesterov, RMSProp,
   ADAGrad, AdaMax, ADADelta, AMSGrad, NADAM,
-  InvDecay, ExpDecay
+  ADAMW, InvDecay, ExpDecay, WeightDecay, DescentWeightDecay
 include("utils.jl")
 include("onehot.jl")

--- a/src/optimise/Optimise.jl
+++ b/src/optimise/Optimise.jl

@@ -3,7 +3,7 @@ module Optimise
 export train!,
   Descent, ADAM, Momentum, Nesterov, RMSProp,
   ADAGrad, AdaMax, ADADelta, AMSGrad, NADAM, ADAMW,
-  InvDecay, ExpDecay, stop, Compose
+  InvDecay, ExpDecay, WeightDecay, stop, Optimiser
 include("optimisers.jl")
 include("train.jl")

--- a/src/optimise/deprecations.jl
+++ b/src/optimise/deprecations.jl

@@ -5,9 +5,9 @@ function check_decay(opt, decay)
     opt = opt
   else
     if opt isa ADAMW
-      opt = Compose(opt, DescentWeightDecay(1, decay))
+      opt = Optimiser(opt, DescentWeightDecay(1, decay))
     else
-      opt = Compose(opt, InvDecay(decay))
+      opt = Optimiser(opt, InvDecay(decay))
     end
   end
   opt
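
A minimal sketch of what the rewritten `check_decay` produces, assuming the definitions above (values are illustrative):

```julia
opt = ADAM(0.001)
opt = check_decay(opt, 0.1)
# Non-ADAMW optimisers get wrapped as Optimiser(ADAM(0.001), InvDecay(0.1));
# an ADAMW would instead be wrapped with DescentWeightDecay(1, 0.1).
```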
@@ -126,3 +126,9 @@ function ADAMW(params::AbstractArray, η = 0.001; β1 = 0.9, β2 = 0.999, decay
   opt = check_decay(opt, decay)
   updaterule(opt, ps)
 end
+
+# Train function
+function train!(loss::Function, data, opt; cb = () -> ())
+  depwarn("train!(loss, data, opt; cb) is deprecated; use train!(model, loss, data, opt; cb) instead", :train)
+  train!(opt.ps, loss, data, opt.opt; cb = cb)
+end
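
The deprecated method above forwards to the new parameter-first `train!`; a hedged sketch of the call it ends up making, assuming `opt` is one of the old-style wrapper optimisers that carries `ps` and `opt` fields:

```julia
# Hypothetical old-style call; `opt` here must expose `opt.ps` and `opt.opt`.
train!(loss, data, opt)   # prints the depwarn, then effectively runs:
train!(opt.ps, loss, data, opt.opt; cb = () -> ())
```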

--- a/src/optimise/optimisers.jl
+++ b/src/optimise/optimisers.jl

@@ -85,7 +85,7 @@ function update!(o::RMSProp, x, Δ)
 end
 
 """
-    ADAM(η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08)
+    ADAM(η = 0.001, β = (0.9, 0.999))
 
 [ADAM](https://arxiv.org/abs/1412.6980v8) optimiser.
 """
@@ -226,28 +226,29 @@ end
 
 [ADAMW](https://arxiv.org/abs/1711.05101) fixing weight decay regularization in Adam.
 """
-ADAMW(η = 0.001, β = (0.9, 0.999), η_decay = 1, γ_decay = 0) = Compose(ADAM(η, β, IdDict()), DescentWeightDecay(η_decay, γ_decay))
+ADAMW(η = 0.001, β = (0.9, 0.999), η_decay = 1, γ_decay = 0) = Optimiser(ADAM(η, β, IdDict()), DescentWeightDecay(η_decay, γ_decay))
 
 # Compose optimizers
 """
-    Compose(a, b, c...)
+    Optimiser(a, b, c...)
 
 Combine several optimisers into one; each optimiser produces a modified gradient
 that will be fed into the next, and this is finally applied to the parameter as
 usual.
 """
-mutable struct Compose
+mutable struct Optimiser
   os::Vector{Any}
-  Compose(o...) = Compose(Any[o...])
 end
 
-@forward Compose.os Base.getindex, Base.first, Base.last, Base.lastindex, Base.push!, Base.setindex!
-@forward Compose.os Base.iterate
+Optimiser(o...) = Optimiser(Any[o...])
 
-Base.getindex(c::Compose, i::AbstractArray) = Compose(c.os[i]...)
+@forward Optimiser.os Base.getindex, Base.first, Base.last, Base.lastindex, Base.push!, Base.setindex!
+@forward Optimiser.os Base.iterate
 
-function update!(o::Compose, x, Δ)
+Base.getindex(c::Optimiser, i::AbstractArray) = Optimiser(c.os[i]...)
 
+function update!(o::Optimiser, x, Δ)
   for opt in o.os
     Δ = update!(opt, x, Δ)
   end
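
A sketch of the composition semantics the docstring describes, assuming the definitions above: `update!` threads the gradient through each member in order and returns the final step.

```julia
opt = Optimiser(InvDecay(0.001), Descent(0.1))
# Conceptually, update!(opt, x, Δ) computes
#   update!(Descent(0.1), x, update!(InvDecay(0.001), x, Δ))
# and the caller applies the resulting step to x.
```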
@@ -281,14 +282,15 @@ function update!(o::ExpDecay, x, Δ)
   @. Δ += γ * x
 end
 
-mutable struct DescentWeightDecay
+mutable struct WeightDecay
   eta::Real
-  gamma::Real
+  wd::Real
 end
 
-DescentWeightDecay(η = 1) = DescentWeightDecay(η, 0)
-function update!(o::DescentWeightDecay, x, Δ)
-  η, γ = o.eta, o.gamma
-  @. x = x - η * (Δ + γ * x)
-  Δ
+WeightDecay(η = 1) = WeightDecay(η, 0)
+function update!(o::WeightDecay, x, Δ)
+  η, wd = o.eta, o.wd
+  @. Δ += wd * x
 end
+
+DescentWeightDecay(η = 0.1, γ = 0) = Optimiser(WeightDecay(), Descent(η))
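
After this change `WeightDecay` is a pure gradient transform: it only adds `wd * x` to the gradient, and the actual step is delegated to `Descent` via composition. A minimal sketch with illustrative values, assuming the definitions above:

```julia
wd = WeightDecay(1, 0.01)  # fields: eta = 1 (unused by this update!), wd = 0.01
x, Δ = randn(10), randn(10)
Δ = update!(wd, x, Δ)      # Δ now includes the 0.01 .* x decay term
# The descent-plus-decay combination is now expressed by composition:
opt = DescentWeightDecay(0.1)   # == Optimiser(WeightDecay(), Descent(0.1))
```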

--- a/src/optimise/train.jl
+++ b/src/optimise/train.jl

@@ -45,7 +45,7 @@ function stop()
 end
 
 """
-    train!(loss, data, opt)
+    train!(model, loss, data, opt)
 
 For each datapoint `d` in `data` computes the gradient of `loss(d...)` through
 backpropagation and calls the optimizer `opt`.
@@ -54,7 +54,7 @@ Takes a callback as keyword argument `cb`. For example, this will print "trainin
 every 10 seconds:
 
 ```julia
-Flux.train!(loss, data, opt,
+Flux.train!(model, loss, data, opt,
             cb = throttle(() -> println("training"), 10))
 ```
@@ -62,14 +62,14 @@ The callback can return `:stop` to interrupt the training loop.
 
 Multiple optimisers and callbacks can be passed to `opt` and `cb` as arrays.
 """
-function train!(loss, data, opt; cb = () -> ())
+function train!(ps::Array, loss, data, opt; cb = () -> ())
   cb = runall(cb)
   opt = runall(opt)
   @progress for d in data
     try
       l = loss(d...)
       @interrupts back!(l)
-      opt()
+      foreach(x -> x.data .-= update!(opt, x.data, x.grad), ps)
       if cb() == :stop
         depwarn("Use of `:stop` is deprecated; use `Flux.stop()` instead", :stop)
         break
@@ -83,6 +83,7 @@ function train!(loss, data, opt; cb = () -> ())
     end
   end
 end
+train!(model, loss, data, opt; cb = () -> ()) = train!(params(model), loss, data, opt; cb = cb)
 
 """
     @epochs N body
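
A hedged end-to-end sketch of the new entry point defined above; the model and data here are hypothetical, and `Dense`, `mse`, `params`, and `throttle` are assumed from Flux of this era:

```julia
m = Dense(10, 2)
loss(x, y) = Flux.mse(m(x), y)
data = [(rand(10), rand(2)) for _ = 1:100]
# Dispatches to train!(params(m), loss, data, Descent(0.1); cb = ...)
Flux.train!(m, loss, data, Descent(0.1);
            cb = Flux.throttle(() -> println("training"), 10))
```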

--- a/test/optimise.jl
+++ b/test/optimise.jl

@@ -23,12 +23,12 @@ using Test
   end
 end
 
-@testset "Compose" begin
+@testset "Optimiser" begin
   w = randn(10, 10)
   @testset for Opt in [InvDecay, ExpDecay]
     w′ = param(randn(10, 10))
     loss(x) = Flux.mse(w*x, w′*x)
-    opt = Compose(vec([Opt(), ADAM(0.001)]))
+    opt = Optimiser(Opt(), ADAM(0.001))
     for t = 1:10^5
       l = loss(rand(10))
       back!(l)