fixed weight decay definition
This commit is contained in:
parent 0f2019eba5
commit fe8c147f72
@@ -21,7 +21,7 @@ using .Optimise
 using .Optimise: @epochs
 export Descent, ADAM, Momentum, Nesterov, RMSProp,
   ADAGrad, AdaMax, ADADelta, AMSGrad, NADAM,
-  InvDecay, ExpDecay
+  ADAMW, InvDecay, ExpDecay, WeightDecay, DescentWeightDecay

 include("utils.jl")
 include("onehot.jl")
@@ -3,7 +3,7 @@ module Optimise
 export train!,
   Descent, ADAM, Momentum, Nesterov, RMSProp,
   ADAGrad, AdaMax, ADADelta, AMSGrad, NADAM, ADAMW,
-  InvDecay, ExpDecay, stop, Compose
+  InvDecay, ExpDecay, WeightDecay, stop, Optimiser

 include("optimisers.jl")
 include("train.jl")
@@ -5,9 +5,9 @@ function check_decay(opt, decay)
     opt = opt
   else
     if opt isa ADAMW
-      opt = Compose(opt, DescentWeightDecay(1, decay))
+      opt = Optimiser(opt, DescentWeightDecay(1, decay))
     else
-      opt = Compose(opt, InvDecay(decay))
+      opt = Optimiser(opt, InvDecay(decay))
     end
   end
   opt
@@ -126,3 +126,9 @@ function ADAMW(params::AbstractArray, η = 0.001; β1 = 0.9, β2 = 0.999, decay
   opt = check_decay(opt, decay)
   updaterule(opt, ps)
 end
+
+# Train function
+function train!(loss::Function, data, opt; cb = () -> ())
+  depwarn("train!(loss, data, opt; cb) is deprecated; use train!(model, data, loss, opt; cb) instead", :train)
+  train!(opt.ps, loss, data, opt.opt; cb = cb)
+end
@@ -85,7 +85,7 @@ function update!(o::RMSProp, x, Δ)
 end

 """
-    ADAM(η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08)
+    ADAM(η = 0.001, β = (0.9, 0.999))

 [ADAM](https://arxiv.org/abs/1412.6980v8) optimiser.
 """
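The updated docstring reflects that the β parameters are now passed positionally as a tuple rather than as keywords. A quick usage sketch, assuming the exports of this revision of Flux:

```julia
using Flux  # this revision exports ADAM from Flux.Optimise

# Equivalent to the old keyword form ADAM(0.001; β1 = 0.9, β2 = 0.999)
opt = ADAM(0.001, (0.9, 0.999))
```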
@@ -226,28 +226,29 @@ end

 [ADAMW](https://arxiv.org/abs/1711.05101) fixing weight decay regularization in Adam.
 """
-ADAMW(η = 0.001, β = (0.9, 0.999), η_decay = 1, γ_decay = 0) = Compose(ADAM(η, β, IdDict()), DescentWeightDecay(η_decay, γ_decay))
+ADAMW(η = 0.001, β = (0.9, 0.999), η_decay = 1, γ_decay = 0) = Optimiser(ADAM(η, β, IdDict()), DescentWeightDecay(η_decay, γ_decay))

 # Compose optimizers

 """
-    Compose(a, b, c...)
+    Optimiser(a, b, c...)

 Combine several optimisers into one; each optimiser produces a modified gradient
 that will be fed into the next, and this is finally applied to the parameter as
 usual.
 """
-mutable struct Compose
+mutable struct Optimiser
   os::Vector{Any}
-  Compose(o...) = Compose(Any[o...])
 end

-@forward Compose.os Base.getindex, Base.first, Base.last, Base.lastindex, Base.push!, Base.setindex!
-@forward Compose.os Base.iterate
+Optimiser(o...) = Optimiser(Any[o...])

-Base.getindex(c::Compose, i::AbstractArray) = Compose(c.os[i]...)
+@forward Optimiser.os Base.getindex, Base.first, Base.last, Base.lastindex, Base.push!, Base.setindex!
+@forward Optimiser.os Base.iterate

-function update!(o::Compose, x, Δ)
+Base.getindex(c::Optimiser, i::AbstractArray) = Optimiser(c.os[i]...)
+
+function update!(o::Optimiser, x, Δ)
   for opt in o.os
     Δ = update!(opt, x, Δ)
   end
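As an illustration of the renamed type: a minimal sketch of chaining two optimisers with `Optimiser`, mirroring the updated test further down. The parameter and gradient values are made up, and it assumes the `update!` methods defined in this file.

```julia
import Flux.Optimise: Optimiser, InvDecay, ADAM, update!

# InvDecay rescales the gradient first, then ADAM applies its adaptive rule.
opt = Optimiser(InvDecay(), ADAM(0.001))

x = randn(4)             # a parameter array (illustrative)
Δ = randn(4)             # its gradient
Δ = update!(opt, x, Δ)   # each optimiser in `os` transforms Δ in turn
x .-= Δ                  # the caller applies the final step, as train! does
```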
@@ -281,14 +282,15 @@ function update!(o::ExpDecay, x, Δ)
   @. Δ += γ * x
 end

-mutable struct DescentWeightDecay
+mutable struct WeightDecay
   eta::Real
-  gamma::Real
+  wd::Real
 end

-DescentWeightDecay(η = 1) = DescentWeightDecay(η, 0)
-function update!(o::DescentWeightDecay, x, Δ)
-  η, γ = o.eta, o.gamma
-  @. x = x - η * (Δ + γ * x)
-  Δ
+WeightDecay(η = 1) = WeightDecay(η, 0)
+function update!(o::WeightDecay, x, Δ)
+  η, wd = o.eta, o.wd
+  @. Δ += wd * x
 end
+
+DescentWeightDecay(η = 0.1, γ = 0) = Optimiser(WeightDecay(), Descent(η))
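This hunk is the actual weight decay fix: the old `DescentWeightDecay` performed its own descent step (`x .= x .- η .* (Δ .+ γ .* x)`), while the new `WeightDecay` only adds the decay term to the gradient and leaves the parameter step to whatever optimiser follows it in the chain. A sketch of the resulting update for plain gradient descent, assuming `Descent(η)` simply scales the gradient by `η`; the concrete numbers are illustrative.

```julia
import Flux.Optimise: Optimiser, WeightDecay, Descent, update!

η, wd = 0.1, 1e-4
x = randn(5)   # parameter (illustrative)
Δ = randn(5)   # gradient of the loss w.r.t. x

# New formulation: WeightDecay rewrites the gradient (Δ += wd * x),
# Descent scales it by η, and the caller applies the step.
opt = Optimiser(WeightDecay(1, wd), Descent(η))  # fields (eta, wd) from the struct above
x_new = x .- update!(opt, x, copy(Δ))

# The same step written out by hand, i.e. the old DescentWeightDecay rule:
x_ref = x .- η .* (Δ .+ wd .* x)

@assert x_new ≈ x_ref
```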
@@ -45,7 +45,7 @@ function stop()
 end

 """
-    train!(loss, data, opt)
+    train!(model, loss, data, opt)

 For each datapoint `d` in `data` computes the gradient of `loss(d...)` through
 backpropagation and calls the optimizer `opt`.
@@ -54,7 +54,7 @@ Takes a callback as keyword argument `cb`. For example, this will print "trainin
 every 10 seconds:

 ```julia
-Flux.train!(loss, data, opt,
+Flux.train!(model, loss, data, opt,
             cb = throttle(() -> println("training"), 10))
 ```

@@ -62,14 +62,14 @@ The callback can return `:stop` to interrupt the training loop.

 Multiple optimisers and callbacks can be passed to `opt` and `cb` as arrays.
 """
-function train!(loss, data, opt; cb = () -> ())
+function train!(ps::Array, loss, data, opt; cb = () -> ())
   cb = runall(cb)
   opt = runall(opt)
   @progress for d in data
     try
       l = loss(d...)
       @interrupts back!(l)
-      opt()
+      foreach(x -> x.data .-= update!(opt, x.data, x.grad), ps)
       if cb() == :stop
         depwarn("Use of `:stop` is deprecated; use `Flux.stop()` instead", :stop)
         break
@@ -83,6 +83,7 @@ function train!(loss, data, opt; cb = () -> ())
     end
   end
 end
+train!(model, loss, data, opt; cb = () -> ()) = train!(params(model), loss, data, opt; cb = cb)

 """
     @epochs N body
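A hedged usage sketch of the new entry point, where the model comes first and `train!` collects the parameters via `params(model)`. The model, loss, and dataset below are hypothetical placeholders, not part of this commit.

```julia
using Flux
using Flux: throttle

model = Dense(10, 1)                        # hypothetical toy model
loss(x, y) = Flux.mse(model(x), y)          # train! calls loss(d...) for each datapoint d
data = [(rand(10), rand(1)) for _ = 1:100]  # made-up (input, target) pairs
opt = ADAM(0.001)

# New signature: train!(model, loss, data, opt) forwards to
# train!(params(model), loss, data, opt), defined in this commit.
Flux.train!(model, loss, data, opt,
            cb = throttle(() -> println("training"), 10))
```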
@@ -23,12 +23,12 @@ using Test
   end
 end

-@testset "Compose" begin
+@testset "Optimiser" begin
   w = randn(10, 10)
   @testset for Opt in [InvDecay, ExpDecay]
     w′ = param(randn(10, 10))
     loss(x) = Flux.mse(w*x, w′*x)
-    opt = Compose(vec([Opt(), ADAM(0.001)]))
+    opt = Optimiser(Opt(), ADAM(0.001))
     for t = 1:10^5
       l = loss(rand(10))
       back!(l)