fixed ExpDecay update! rule
This commit is contained in:
parent
32ce2d78b8
commit
bebf4eb95f
|
@ -272,26 +272,25 @@ function update!(o::InvDecay, x, Δ)
|
|||
end
|
||||
|
||||
mutable struct ExpDecay
|
||||
opt
|
||||
eta::Float64
|
||||
decay::Float64
|
||||
step::Int64
|
||||
clip::Float64
|
||||
current::IdDict
|
||||
end
|
||||
|
||||
ExpDecay(opt = Descent(), decay = 0.1, decay_step = 1000, clip = 1e-4) = ExpDecay(opt, decay, decay_step, clip, IdDict())
|
||||
ExpDecay(opt = 0.001, decay = 0.1, decay_step = 1000, clip = 1e-4) = ExpDecay(opt, decay, decay_step, clip, IdDict())
|
||||
|
||||
function update!(o::ExpDecay, x, Δ)
|
||||
s, decay = o.step, o.decay
|
||||
η = try o.opt.eta; catch e; o.opt.rho; end
|
||||
η, s, decay = o.eta, o.step, o.decay
|
||||
n = o.current[x] = get(o.current, x, 0) + 1
|
||||
flag = false
|
||||
count(x -> x%s == 0, values(o.current)) == 1 && (flag = true)
|
||||
if o.current[x]%s == 0 && flag
|
||||
η = max(η * decay^(s / n), o.clip)
|
||||
o.opt isa ADADelta ? o.opt.rho = η : o.opt.eta = η
|
||||
o.eta = η
|
||||
end
|
||||
update!(o.opt, x, Δ)
|
||||
@. Δ *= decay
|
||||
end
|
||||
|
||||
mutable struct WeightDecay
|
||||
|
|
|
@ -29,9 +29,6 @@ end
|
|||
w′ = param(randn(10, 10))
|
||||
loss(x) = Flux.mse(w*x, w′*x)
|
||||
opt = Optimiser(Opt(), ADAM(0.001))
|
||||
if Opt isa ExpDecay
|
||||
opt = ExpDecay(ADAM(), 0.9, 1000)
|
||||
end
|
||||
for t = 1:10^5
|
||||
l = loss(rand(10))
|
||||
back!(l)
|
||||
|
|
Loading…
Reference in New Issue