decay fixes
parent edbcd3c9ea · commit 815e8c206d

@@ -5,7 +5,7 @@ function check_decay(opt, decay)
     opt = opt
   else
     if opt isa ADAMW
-      opt = Optimiser(opt, DescentWeightDecay(1, decay))
+      opt = Optimiser(opt, WeightDecay(decay))
     else
       opt = Optimiser(opt, InvDecay(decay))
     end

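Note: decay for `ADAMW` is now expressed as a plain `WeightDecay(decay)` stage composed into the optimiser chain, matching the reworked `DescentWeightDecay` in the optimiser hunk below. A minimal standalone sketch of how such a chain threads a gradient through each stage in turn; the `Toy*` names are illustrative stand-ins, not Flux's implementation:

```julia
# Each stage transforms the gradient and passes it on.
struct ToyWeightDecay
  wd::Float64
end
apply!(o::ToyWeightDecay, x, Δ) = (Δ .+= o.wd .* x; Δ)  # add the L2 penalty term

struct ToyDescent
  eta::Float64
end
apply!(o::ToyDescent, x, Δ) = (Δ .*= o.eta; Δ)  # scale by the learning rate

struct ToyChain
  os::Vector{Any}
end

# Fold the gradient through every stage in order, as Optimiser(a, b, ...) does.
function apply!(o::ToyChain, x, Δ)
  for opt in o.os
    Δ = apply!(opt, x, Δ)
  end
  return Δ
end

x, Δ = randn(3), ones(3)
x .-= apply!(ToyChain([ToyWeightDecay(0.01), ToyDescent(0.1)]), x, Δ)
```
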
@@ -129,6 +129,10 @@ end
 
 # Train function
 function train!(loss::Function, data, opt; cb = () -> ())
-  depwarn("train!(loss, data, opt; cb) is deprecated; use train!(model, data, loss, opt; cb) instead", :train)
-  train!(opt.ps, loss, data, opt.opt; cb = cb)
-end
+  depwarn("train!(loss, data, opt; cb) is deprecated; use train!(loss, params, data, opt; cb) instead", :train)
+  if fieldnames(typeof(opt)) !== ()
+    train!(loss, opt.ps, data, opt.opt; cb = cb)
+  else
+    train!(loss, (), data, opt; cb = cb)
+  end
+end

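Note: the deprecation shim distinguishes old-style wrapper optimisers, which carried their parameters in an `opt.ps` field, from bare functions and new-style optimisers by checking whether the type has any fields at all. A standalone illustration of that check; `OldStyleOpt` is a hypothetical stand-in, not a Flux type:

```julia
struct OldStyleOpt  # hypothetical stand-in for a pre-redesign optimiser wrapper
  ps::Vector{Any}
  opt::Any
end

has_params(opt) = fieldnames(typeof(opt)) !== ()

has_params(OldStyleOpt([], nothing))  # true  -> shim forwards opt.ps and opt.opt
has_params(() -> ())                  # false -> shim forwards empty params instead
```
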
@@ -258,38 +258,52 @@ end
 
 mutable struct InvDecay
   gamma::Float64
-  n::Int64
+  state::IdDict
 end
 
-InvDecay(γ = 0.001) = InvDecay(γ, 0)
+InvDecay(γ = 0.001) = InvDecay(γ, IdDict())
 
 function update!(o::InvDecay, x, Δ)
-  γ, n = o.gamma, o.n
+  γ = o.gamma
+  n = get!(o.state, x, 1)
   Δ .*= 1 / (1 + γ * n)
-  o.n += 1
+  o.state[x] = n + 1
+  return Δ
 end
 
 mutable struct ExpDecay
-  gamma::Float64
+  opt
+  decay::Float64
+  step::Int64
+  clip::Float64
+  current::IdDict
 end
 
-ExpDecay() = ExpDecay(0.001)
+ExpDecay(opt, decay = 0.1, decay_step = 1000, clip = 1e-4) = ExpDecay(opt, decay, decay_step, clip, IdDict())
 
 function update!(o::ExpDecay, x, Δ)
-  γ = o.gamma
-  @. Δ += γ * x
+  s, decay = o.step, o.decay
+  η = try o.opt.eta; catch e; o.opt.rho; end
+  n = o.current[x] = get(o.current, x, 0) + 1
+  flag = false
+  count(x -> x%s == 0, values(o.current)) == 1 && (flag = true)
+  if o.current[x]%s == 0 && flag
+    η = max(η * decay^(s / n), o.clip)
+    o.opt isa ADADelta ? o.opt.rho = η : o.opt.eta = η
+  end
+  update!(o.opt, x, Δ)
 end
 
 mutable struct WeightDecay
-  eta::Real
   wd::Real
 end
 
-WeightDecay(η = 1) = WeightDecay(η, 0)
+WeightDecay() = WeightDecay(0)
 function update!(o::WeightDecay, x, Δ)
-  η, wd = o.eta, o.wd
+  wd = o.wd
   @. Δ += wd * x
 end
 
-DescentWeightDecay(η = 1, wd = 0) = Optimiser(WeightDecay(1, wd), Descent(η))
+DescentWeightDecay(η = 1, wd = 0) = Optimiser(WeightDecay(wd), Descent(η))
 
+update!(opt::Function, ps) = opt()

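Note: this hunk is the heart of the fix. `InvDecay` previously kept a single global counter `n`, so every parameter shared one decay schedule; its state now lives in an `IdDict` keyed by the parameter array, giving each parameter its own counter. `ExpDecay` likewise becomes a wrapper that tracks per-parameter step counts and reads the wrapped optimiser's learning rate via `eta` (or `rho` for `ADADelta`). A standalone sketch of the per-parameter counter pattern; `ToyInvDecay` is illustrative, not the Flux type:

```julia
mutable struct ToyInvDecay
  gamma::Float64
  state::IdDict
end
ToyInvDecay(γ = 0.001) = ToyInvDecay(γ, IdDict())

function apply!(o::ToyInvDecay, x, Δ)
  n = get!(o.state, x, 1)      # counter for *this* array, starting at 1
  Δ .*= 1 / (1 + o.gamma * n)  # same 1/(1 + γn) schedule as the diff
  o.state[x] = n + 1           # advance only this array's counter
  return Δ
end

a, b = randn(2), randn(2)
opt = ToyInvDecay(0.1)
apply!(opt, a, ones(2)); apply!(opt, a, ones(2))  # `a` seen twice
apply!(opt, b, ones(2))                           # `b` seen once
(opt.state[a], opt.state[b])                      # (3, 2): independent schedules
```
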
@@ -4,9 +4,8 @@ import Base.depwarn
 
 function update!(opt, xs)
   for x in xs
-    x, Δ = data(x), grad(x)
-    Δ = update!(opt, x, Δ)
-    x .-= Δ
+    Δ = update!(opt, x.data, x.grad)
+    x.data .-= Δ
     Δ .= 0
   end
 end

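Note: the update loop now reaches into the tracked arrays directly. One subtlety worth flagging: the final `Δ .= 0` only clears the stored gradient because each `update!(opt, x, Δ)` above mutates and returns the array it was handed, so `Δ` stays aliased to `x.grad`; an optimiser that returned a fresh array would leave the gradient uncleared. A minimal illustration of that aliasing contract, with `ToyParam` and `scale!` as stand-ins for `TrackedArray` and `update!`:

```julia
mutable struct ToyParam
  data::Vector{Float64}
  grad::Vector{Float64}
end

scale!(Δ) = (Δ .*= 0.1; Δ)  # mutates and returns the *same* array

function step!(xs)
  for x in xs
    Δ = scale!(x.grad)  # Δ aliases x.grad
    x.data .-= Δ        # apply the step
    Δ .= 0              # zeroes x.grad through the alias
  end
end

p = ToyParam(ones(3), fill(2.0, 3))
step!([p])
(p.data, p.grad)  # ([0.8, 0.8, 0.8], [0.0, 0.0, 0.0])
```
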
@@ -62,14 +61,20 @@ The callback can return `:stop` to interrupt the training loop.
 
 Multiple optimisers and callbacks can be passed to `opt` and `cb` as arrays.
 """
-function train!(ps::Array, loss, data, opt; cb = () -> ())
+function train!(loss, ps, data, opt; cb = () -> ())
   cb = runall(cb)
-  opt = runall(opt)
+  opt = try
+    opt()
+    opt.opt
+  catch
+    opt
+  end
   @progress for d in data
     try
       l = loss(d...)
       @interrupts back!(l)
-      foreach(x -> x.data .-= update!(opt, x.data, x.grad), ps)
+      update!(opt, ps)
       if cb() == :stop
         depwarn("Use of `:stop` is deprecated; use `Flux.stop()` instead", :stop)
         break

@@ -83,7 +88,6 @@ function train!(ps::Array, loss, data, opt; cb = () -> ())
       end
     end
 end
-train!(model, loss, data, opt; cb = () -> ()) = train!(params(model), loss, data, opt; cb = cb)
 
 """
     @epochs N body

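Note: `train!` now takes the parameters explicitly as `train!(loss, ps, data, opt; cb)`, and the `try`/`catch` block unwraps an old callable wrapper (falling back to using `opt` as-is). A hedged before/after of a typical call site; `m` and `data` are placeholder names, and `SGD(params(m), η)` is the pre-redesign API:

```julia
# old: the optimiser carried the parameters itself
#   Flux.train!(loss, data, SGD(params(m), 0.1))
# new: the parameters are passed explicitly alongside a stateless optimiser
#   Flux.train!(loss, params(m), data, Descent(0.1))
```
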
@@ -16,7 +16,7 @@ using Test
   for t = 1: 10^5
     l = loss(rand(10))
     back!(l)
-    delta = Optimise.update!(opt, w′.data, w′.grad)
+    delta = Optimise.update!(opt, w′)
     w′.data .-= delta
   end
   @test Flux.mse(w, w′) < 0.01

@@ -25,14 +25,16 @@ end
 
 @testset "Optimiser" begin
   w = randn(10, 10)
-  @testset for Opt in [InvDecay, ExpDecay]
+  @testset for Opt in [InvDecay, WeightDecay, ExpDecay]
     w′ = param(randn(10, 10))
     loss(x) = Flux.mse(w*x, w′*x)
     opt = Optimiser(Opt(), ADAM(0.001))
+    if Opt isa ExpDecay
+      opt = ExpDecay(ADAM(), 0.9)
+    end
     for t = 1:10^5
       l = loss(rand(10))
       back!(l)
-      delta = Optimise.update!(opt, w′.data, w′.grad)
+      delta = Optimise.update!(opt, w′)
       w′.data .-= delta
     end
     @test Flux.mse(w, w′) < 0.01

@@ -45,7 +47,7 @@ end
   Flux.train!(() -> (sleep(0.1); i += 1; l),
               Iterators.repeated((), 100),
-              ADAM([l]),
+              () -> (),
               cb = Flux.throttle(() -> (i > 3 && Flux.stop()), 1))
 
   @test 3 < i < 50

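Note: in the `@testset for Opt in [...]` loop above, `Opt` is bound to a *type*, so `Opt isa ExpDecay` is always false and the `ExpDecay(ADAM(), 0.9)` override never fires; `Opt == ExpDecay` is presumably what was intended. A standalone illustration, with `Foo` standing in for `ExpDecay`:

```julia
struct Foo end

T = Foo
T isa Foo      # false: T is the type object itself (a DataType)
T === Foo      # true:  the comparison the branch presumably wants
Foo() isa Foo  # true:  `isa` tests instances against types
```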