decay fixes

parent edbcd3c9ea
commit 815e8c206d
@@ -5,7 +5,7 @@ function check_decay(opt, decay)
     opt = opt
   else
     if opt isa ADAMW
-      opt = Optimiser(opt, DescentWeightDecay(1, decay))
+      opt = Optimiser(opt, WeightDecay(decay))
     else
       opt = Optimiser(opt, InvDecay(decay))
     end
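Note: with this change, decay for `ADAMW` is expressed by chaining a bare `WeightDecay` after the optimiser rather than the old `DescentWeightDecay` combination. `Optimiser` threads the gradient through each stage in order; a minimal sketch of that chaining, assuming the `os` field of Flux's `Optimiser` container of this era:

    # Sketch (not from this commit): how Optimiser chains update! calls.
    # Assumes Optimiser stores its stages in a field `os`.
    function update!(o::Optimiser, x, Δ)
      for opt in o.os
        Δ = update!(opt, x, Δ)  # each stage transforms the gradient in turn
      end
      return Δ
    end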
@@ -129,6 +129,10 @@ end

 # Train function
 function train!(loss::Function, data, opt; cb = () -> ())
-  depwarn("train!(loss, data, opt; cb) is deprecated; use train!(model, data, loss, opt; cb) instead", :train)
-  train!(opt.ps, loss, data, opt.opt; cb = cb)
+  depwarn("train!(loss, data, opt; cb) is deprecated; use train!(loss, params, data, opt; cb) instead", :train)
+  if fieldnames(typeof(opt)) !== ()
+    train!(loss, opt.ps, data, opt.opt; cb = cb)
+  else
+    train!(loss, (), data, opt; cb = cb)
+  end
 end
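Note: this shim keeps the deprecated three-argument `train!` alive. A deprecated wrapper object carries its params in `opt.ps` and the real optimiser in `opt.opt` (hence the `fieldnames` check); anything else is forwarded with an empty parameter tuple. A sketch of the wrapper shape being unpacked; the name `DeprecatedOpt` is hypothetical:

    # Hypothetical stand-in for the old wrapper type this shim handles:
    struct DeprecatedOpt
      ps   # parameters captured when the optimiser was built
      opt  # the actual optimiser, e.g. ADAM()
    end
    # train!(loss, data, DeprecatedOpt(ps, ADAM())) now forwards to
    # train!(loss, ps, data, ADAM(); cb = cb)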
|
@@ -258,38 +258,52 @@ end

 mutable struct InvDecay
   gamma::Float64
-  n::Int64
+  state::IdDict
 end

-InvDecay(γ = 0.001) = InvDecay(γ, 0)
+InvDecay(γ = 0.001) = InvDecay(γ, IdDict())

 function update!(o::InvDecay, x, Δ)
-  γ, n = o.gamma, o.n
+  γ = o.gamma
+  n = get!(o.state, x, 1)
   Δ .*= 1 / (1 + γ * n)
-  o.n += 1
+  o.state[x] = n + 1
   return Δ
 end

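Note: `InvDecay` now tracks one step counter per parameter array in an `IdDict` instead of a single global `n`, so each array's gradients are damped by `1 / (1 + γ·n)` according to how often that particular array has been updated. A small sketch of the per-parameter behaviour, given the definitions above:

    o = InvDecay(0.5)          # γ = 0.5, fresh per-parameter state
    w1, w2 = rand(3), rand(3)
    update!(o, w1, ones(3))    # w1 at n = 1: Δ .*= 1 / (1 + 0.5 * 1)
    update!(o, w1, ones(3))    # w1 at n = 2: Δ .*= 1 / (1 + 0.5 * 2)
    update!(o, w2, ones(3))    # w2 starts at its own n = 1, independent of w1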
 mutable struct ExpDecay
-  gamma::Float64
+  opt
+  decay::Float64
+  step::Int64
+  clip::Float64
+  current::IdDict
 end

-ExpDecay() = ExpDecay(0.001)
+ExpDecay(opt, decay = 0.1, decay_step = 1000, clip = 1e-4) = ExpDecay(opt, decay, decay_step, clip, IdDict())

 function update!(o::ExpDecay, x, Δ)
-  γ = o.gamma
-  @. Δ += γ * x
+  s, decay = o.step, o.decay
+  η = try o.opt.eta; catch e; o.opt.rho; end
+  n = o.current[x] = get(o.current, x, 0) + 1
+  flag = false
+  count(x -> x%s == 0, values(o.current)) == 1 && (flag = true)
+  if o.current[x]%s == 0 && flag
+    η = max(η * decay^(s / n), o.clip)
+    o.opt isa ADADelta ? o.opt.rho = η : o.opt.eta = η
+  end
+  update!(o.opt, x, Δ)
 end

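Note: `ExpDecay` changes meaning entirely. Instead of adding `γ·x` to the gradient, it now wraps an inner optimiser and multiplies that optimiser's learning rate (`eta`, or `rho` for `ADADelta`) by `decay` once every `step` updates, flooring it at `clip`, before delegating the actual gradient transform to the wrapped optimiser. A usage sketch based on the constructor above:

    # Cut ADAM's learning rate by 10x every 1000 updates, never below 1e-4:
    opt = ExpDecay(ADAM(), 0.1, 1000, 1e-4)
    # used like any other optimiser, e.g. Δ = update!(opt, w.data, w.grad)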
 mutable struct WeightDecay
-  eta::Real
   wd::Real
 end

-WeightDecay(η = 1) = WeightDecay(η, 0)
+WeightDecay() = WeightDecay(0)

 function update!(o::WeightDecay, x, Δ)
-  η, wd = o.eta, o.wd
+  wd = o.wd
   @. Δ += wd * x
 end

-DescentWeightDecay(η = 1, wd = 0) = Optimiser(WeightDecay(1, wd), Descent(η))
+DescentWeightDecay(η = 1, wd = 0) = Optimiser(WeightDecay(wd), Descent(η))

+update!(opt::Function, ps) = opt()
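Note: `WeightDecay` loses its unused `eta` field and now only adds `wd·x` to the gradient; the step size is left to whatever stage follows it in the chain. Composed with `Descent(η)` (assumed here to scale the gradient by its η; `Descent` is defined elsewhere), `DescentWeightDecay` reproduces SGD with L2 regularisation. The new `update!(opt::Function, ps)` method additionally lets a bare function stand in for an optimiser, which the updated tests below rely on.

    # With opt = DescentWeightDecay(η, wd) = Optimiser(WeightDecay(wd), Descent(η)):
    #   Δ ← Δ .+ wd .* x    (WeightDecay)
    #   Δ ← η .* Δ          (Descent, assumed to scale by its η)
    # so the caller's x .-= Δ performs x ← x − η(∇loss + wd·x)
    opt = DescentWeightDecay(0.1, 1e-4)   # η = 0.1, wd = 1e-4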
|
@@ -4,9 +4,8 @@ import Base.depwarn

 function update!(opt, xs)
   for x in xs
-    x, Δ = data(x), grad(x)
-    Δ = update!(opt, x, Δ)
-    x .-= Δ
+    Δ = update!(opt, x.data, x.grad)
+    x.data .-= Δ
     Δ .= 0
   end
 end
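Note: the batch `update!` now reads the gradient straight off the tracked parameter, lets the optimiser transform it, applies it to `x.data`, and zeroes the returned gradient buffer in place. A rough sketch of the effect, assuming Tracker-style `param`/`back!` as in the tests below and a `Descent` that rescales gradients by its `eta`:

    w = param(rand(3))           # tracked array exposing w.data and w.grad
    back!(sum(w .* w))           # populate w.grad
    update!(Descent(0.1), [w])   # w.data .-= 0.1 .* w.grad; returned Δ zeroed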
@@ -62,14 +61,20 @@ The callback can return `:stop` to interrupt the training loop.

 Multiple optimisers and callbacks can be passed to `opt` and `cb` as arrays.
 """
-function train!(ps::Array, loss, data, opt; cb = () -> ())
+function train!(loss, ps, data, opt; cb = () -> ())
   cb = runall(cb)
   opt = runall(opt)
+  opt = try
+    opt()
+    opt.opt
+  catch
+    opt
+  end
   @progress for d in data
     try
       l = loss(d...)
       @interrupts back!(l)
-      foreach(x -> x.data .-= update!(opt, x.data, x.grad), ps)
+      update!(opt, ps)
       if cb() == :stop
         depwarn("Use of `:stop` is deprecated; use `Flux.stop()` instead", :stop)
         break
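Note: besides moving `ps` ahead of `data` in the signature, the `try`/`catch` probe accepts an old-style zero-argument closure wrapping an optimiser (it calls it, then reads `opt.opt`), falling back to `opt` itself. A minimal sketch of a training call under the new signature, assuming `Dense`/`params` from Flux of this era:

    m = Dense(10, 2)
    loss(x, y) = Flux.mse(m(x), y)
    data = Iterators.repeated((rand(10), rand(2)), 100)
    Flux.train!(loss, params(m), data, ADAM(0.001))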
@@ -83,7 +88,6 @@ function train!(ps::Array, loss, data, opt; cb = () -> ())
     end
   end
 end
-train!(model, loss, data, opt; cb = () -> ()) = train!(params(model), loss, data, opt; cb = cb)

 """
     @epochs N body
|
@@ -16,7 +16,7 @@ using Test
   for t = 1: 10^5
     l = loss(rand(10))
     back!(l)
-    delta = Optimise.update!(opt, w′.data, w′.grad)
+    delta = Optimise.update!(opt, w′)
     w′.data .-= delta
   end
   @test Flux.mse(w, w′) < 0.01
@@ -25,14 +25,16 @@ end

 @testset "Optimiser" begin
   w = randn(10, 10)
-  @testset for Opt in [InvDecay, ExpDecay]
+  @testset for Opt in [InvDecay, WeightDecay, ExpDecay]
     w′ = param(randn(10, 10))
     loss(x) = Flux.mse(w*x, w′*x)
     opt = Optimiser(Opt(), ADAM(0.001))
+    if Opt isa ExpDecay
+      opt = ExpDecay(ADAM(), 0.9)
     for t = 1:10^5
       l = loss(rand(10))
       back!(l)
-      delta = Optimise.update!(opt, w′.data, w′.grad)
+      delta = Optimise.update!(opt, w′)
       w′.data .-= delta
     end
     @test Flux.mse(w, w′) < 0.01
@@ -45,7 +47,7 @@ end

   Flux.train!(() -> (sleep(0.1); i += 1; l),
               Iterators.repeated((), 100),
-              ADAM([l]),
+              () -> (),
               cb = Flux.throttle(() -> (i > 3 && Flux.stop()), 1))

   @test 3 < i < 50
|