From e8433d0abe492bb3df192fb89651bb25b1aaf49f Mon Sep 17 00:00:00 2001 From: DrChainsaw Date: Tue, 12 May 2020 22:50:17 +0200 Subject: [PATCH] Align ExpDecay implementation with documentation --- src/optimise/optimisers.jl | 2 +- test/optimise.jl | 15 +++++++++++++-- 2 files changed, 14 insertions(+), 3 deletions(-) diff --git a/src/optimise/optimisers.jl b/src/optimise/optimisers.jl index 611edddb..7ede1e72 100644 --- a/src/optimise/optimisers.jl +++ b/src/optimise/optimisers.jl @@ -509,7 +509,7 @@ function apply!(o::ExpDecay, x, Δ) η, s, decay = o.eta, o.step, o.decay n = o.current[x] = get(o.current, x, 0) + 1 if o.current[x]%s == 0 && count(x -> x%s == 0, values(o.current)) == 1 - η = max(η * decay^(s / n), o.clip) + η = max(η * decay, o.clip) o.eta = η end @. Δ *= η diff --git a/test/optimise.jl b/test/optimise.jl index ac131b96..5ccc897e 100644 --- a/test/optimise.jl +++ b/test/optimise.jl @@ -57,6 +57,17 @@ end end @testset "ExpDecay" begin + + @testset "Sanity Check" begin + o = ExpDecay(0.2, 0.5, 1, 1e-3) + p = [0.0] + steps = 1:8 + eta_expected = @. max(o.eta * 0.5 ^ steps, o.clip) + eta_actual = [Optimise.apply!(o, p, [1.0])[1] for _ in steps] + @test eta_actual == eta_expected + end + + w = randn(10, 10) o = ExpDecay(0.1, 0.1, 1000, 1e-4) w1 = randn(10,10) @@ -81,9 +92,9 @@ end end end @test flag == 1 - # Test to check if decay happens at decay steps. Eta reaches clip value eventually. + # Test to check if decay happens at decay steps. Eta reaches clip value (1e-4) after 4000 steps (decay by 0.1 every 1000 steps starting at 0.1). ground_truth = [] - for i in 1:11 + for i in 1:4 push!(ground_truth, 1000*i) # Expected decay steps for this example. end @test decay_steps == ground_truth