Align ExpDecay implementation with documentation

This commit is contained in:
DrChainsaw 2020-05-12 22:50:17 +02:00
parent de39d1095b
commit e8433d0abe
2 changed files with 14 additions and 3 deletions

View File

@ -509,7 +509,7 @@ function apply!(o::ExpDecay, x, Δ)
η, s, decay = o.eta, o.step, o.decay
n = o.current[x] = get(o.current, x, 0) + 1
if o.current[x]%s == 0 && count(x -> x%s == 0, values(o.current)) == 1
η = max(η * decay^(s / n), o.clip)
η = max(η * decay, o.clip)
o.eta = η
end
@. Δ *= η

View File

@ -57,6 +57,17 @@ end
end
@testset "ExpDecay" begin
@testset "Sanity Check" begin
o = ExpDecay(0.2, 0.5, 1, 1e-3)
p = [0.0]
steps = 1:8
eta_expected = @. max(o.eta * 0.5 ^ steps, o.clip)
eta_actual = [Optimise.apply!(o, p, [1.0])[1] for _ in steps]
@test eta_actual == eta_expected
end
w = randn(10, 10)
o = ExpDecay(0.1, 0.1, 1000, 1e-4)
w1 = randn(10,10)
@ -81,9 +92,9 @@ end
end
end
@test flag == 1
# Test to check if decay happens at decay steps. Eta reaches clip value eventually.
# Test to check if decay happens at decay steps. Eta reaches clip value (1e-4) after 4000 steps (decay by 0.1 every 1000 steps starting at 0.1).
ground_truth = []
for i in 1:11
for i in 1:4
push!(ground_truth, 1000*i) # Expected decay steps for this example.
end
@test decay_steps == ground_truth