fixed ExpDecay update! rule

This commit is contained in:
Dhairya Gandhi 2018-10-29 23:12:24 +05:30
parent 32ce2d78b8
commit bebf4eb95f
2 changed files with 5 additions and 9 deletions

View File

@ -272,26 +272,25 @@ function update!(o::InvDecay, x, Δ)
end
mutable struct ExpDecay
opt
eta::Float64
decay::Float64
step::Int64
clip::Float64
current::IdDict
end
ExpDecay(opt = Descent(), decay = 0.1, decay_step = 1000, clip = 1e-4) = ExpDecay(opt, decay, decay_step, clip, IdDict())
ExpDecay(opt = 0.001, decay = 0.1, decay_step = 1000, clip = 1e-4) = ExpDecay(opt, decay, decay_step, clip, IdDict())
function update!(o::ExpDecay, x, Δ)
s, decay = o.step, o.decay
η = try o.opt.eta; catch e; o.opt.rho; end
η, s, decay = o.eta, o.step, o.decay
n = o.current[x] = get(o.current, x, 0) + 1
flag = false
count(x -> x%s == 0, values(o.current)) == 1 && (flag = true)
if o.current[x]%s == 0 && flag
η = max(η * decay^(s / n), o.clip)
o.opt isa ADADelta ? o.opt.rho = η : o.opt.eta = η
o.eta = η
end
update!(o.opt, x, Δ)
@. Δ *= decay
end
mutable struct WeightDecay

View File

@ -29,9 +29,6 @@ end
w = param(randn(10, 10))
loss(x) = Flux.mse(w*x, w*x)
opt = Optimiser(Opt(), ADAM(0.001))
if Opt isa ExpDecay
opt = ExpDecay(ADAM(), 0.9, 1000)
end
for t = 1:10^5
l = loss(rand(10))
back!(l)