decay fixes

This commit is contained in:
Dhairya Gandhi 2018-10-27 19:26:42 +05:30
parent edbcd3c9ea
commit 815e8c206d
4 changed files with 50 additions and 26 deletions

View File

@ -5,7 +5,7 @@ function check_decay(opt, decay)
opt = opt
else
if opt isa ADAMW
opt = Optimiser(opt, DescentWeightDecay(1, decay))
opt = Optimiser(opt, WeightDecay(decay))
else
opt = Optimiser(opt, InvDecay(decay))
end
@ -129,6 +129,10 @@ end
# Train function
function train!(loss::Function, data, opt; cb = () -> ())
depwarn("train!(loss, data, opt; cb) is deprecated; use train!(model, data, loss, opt; cb) instead", :train)
train!(opt.ps, loss, data, opt.opt; cb = cb)
end
depwarn("train!(loss, data, opt; cb) is deprecated; use train!(loss, params, data, opt; cb) instead", :train)
if fieldnames(typeof(opt)) !== ()
train!(loss, opt.ps, data, opt.opt; cb = cb)
else
train!(loss, (), data, opt; cb = cb)
end
end

View File

@ -258,38 +258,52 @@ end
mutable struct InvDecay
gamma::Float64
n::Int64
state::IdDict
end
InvDecay(γ = 0.001) = InvDecay(γ, 0)
InvDecay(γ = 0.001) = InvDecay(γ, IdDict())
function update!(o::InvDecay, x, Δ)
γ, n = o.gamma, o.n
γ = o.gamma
n = get!(o.state, x, 1)
Δ .*= 1 / (1 + γ * n)
o.n += 1
o.state[x] = n + 1
return Δ
end
mutable struct ExpDecay
gamma::Float64
opt
decay::Float64
step::Int64
clip::Float64
current::IdDict
end
ExpDecay() = ExpDecay(0.001)
ExpDecay(opt, decay = 0.1, decay_step = 1000, clip = 1e-4) = ExpDecay(opt, decay, decay_step, clip, IdDict())
function update!(o::ExpDecay, x, Δ)
γ = o.gamma
@. Δ += γ * x
s, decay = o.step, o.decay
η = try o.opt.eta; catch e; o.opt.rho; end
n = o.current[x] = get(o.current, x, 0) + 1
flag = false
count(x -> x%s == 0, values(o.current)) == 1 && (flag = true)
if o.current[x]%s == 0 && flag
η = max(η * decay^(s / n), o.clip)
o.opt isa ADADelta ? o.opt.rho = η : o.opt.eta = η
end
update!(o.opt, x, Δ)
end
mutable struct WeightDecay
eta::Real
wd::Real
end
WeightDecay(η = 1) = WeightDecay(η, 0)
WeightDecay() = WeightDecay(0)
function update!(o::WeightDecay, x, Δ)
η, wd = o.eta, o.wd
wd = o.wd
@. Δ += wd * x
end
DescentWeightDecay(η = 1, wd = 0) = Optimiser(WeightDecay(1, wd), Descent(η))
DescentWeightDecay(η = 1, wd = 0) = Optimiser(WeightDecay(wd), Descent(η))
update!(opt::Function, ps) = opt()

View File

@ -4,9 +4,8 @@ import Base.depwarn
function update!(opt, xs)
for x in xs
x, Δ = data(x), grad(x)
Δ = update!(opt, x, Δ)
x .-= Δ
Δ = update!(opt, x.data, x.grad)
x.data .-= Δ
Δ .= 0
end
end
@ -62,14 +61,20 @@ The callback can return `:stop` to interrupt the training loop.
Multiple optimisers and callbacks can be passed to `opt` and `cb` as arrays.
"""
function train!(ps::Array, loss, data, opt; cb = () -> ())
function train!(loss, ps, data, opt; cb = () -> ())
cb = runall(cb)
opt = runall(opt)
opt = try
opt()
opt.opt
catch
opt
end
@progress for d in data
try
l = loss(d...)
@interrupts back!(l)
foreach(x -> x.data .-= update!(opt, x.data, x.grad), ps)
update!(opt, ps)
if cb() == :stop
depwarn("Use of `:stop` is deprecated; use `Flux.stop()` instead", :stop)
break
@ -83,7 +88,6 @@ function train!(ps::Array, loss, data, opt; cb = () -> ())
end
end
end
train!(model, loss, data, opt; cb = () -> ()) = train!(params(model), loss, data, opt; cb = cb)
"""
@epochs N body

View File

@ -16,7 +16,7 @@ using Test
for t = 1: 10^5
l = loss(rand(10))
back!(l)
delta = Optimise.update!(opt, w.data, w.grad)
delta = Optimise.update!(opt, w)
w.data .-= delta
end
@test Flux.mse(w, w) < 0.01
@ -25,14 +25,16 @@ end
@testset "Optimiser" begin
w = randn(10, 10)
@testset for Opt in [InvDecay, ExpDecay]
@testset for Opt in [InvDecay, WeightDecay, ExpDecay]
w = param(randn(10, 10))
loss(x) = Flux.mse(w*x, w*x)
opt = Optimiser(Opt(), ADAM(0.001))
if Opt isa ExpDecay
opt = ExpDecay(ADAM(), 0.9)
for t = 1:10^5
l = loss(rand(10))
back!(l)
delta = Optimise.update!(opt, w.data, w.grad)
delta = Optimise.update!(opt, w)
w.data .-= delta
end
@test Flux.mse(w, w) < 0.01
@ -45,7 +47,7 @@ end
Flux.train!(() -> (sleep(0.1); i += 1; l),
Iterators.repeated((), 100),
ADAM([l]),
() -> (),
cb = Flux.throttle(() -> (i > 3 && Flux.stop()), 1))
@test 3 < i < 50