1177: Align ExpDecay implementation with documentation r=dhairyagandhi96 a=DrChainsaw

Fix for #1176 



Co-authored-by: DrChainsaw <Christian.kyril.skarby@gmail.com>
This commit is contained in:
bors[bot] 2020-05-21 14:33:20 +00:00 committed by GitHub
commit bd152ca099
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 40 additions and 30 deletions

View File

@ -509,7 +509,7 @@ function apply!(o::ExpDecay, x, Δ)
η, s, decay = o.eta, o.step, o.decay η, s, decay = o.eta, o.step, o.decay
n = o.current[x] = get(o.current, x, 0) + 1 n = o.current[x] = get(o.current, x, 0) + 1
if o.current[x]%s == 0 && count(x -> x%s == 0, values(o.current)) == 1 if o.current[x]%s == 0 && count(x -> x%s == 0, values(o.current)) == 1
η = max(η * decay^(s / n), o.clip) η = max(η * decay, o.clip)
o.eta = η o.eta = η
end end
@. Δ *= η @. Δ *= η

View File

@ -57,37 +57,47 @@ end
end end
@testset "ExpDecay" begin @testset "ExpDecay" begin
w = randn(10, 10)
o = ExpDecay(0.1, 0.1, 1000, 1e-4) @testset "Sanity Check" begin
w1 = randn(10,10) o = ExpDecay(0.2, 0.5, 1, 1e-3)
loss(x) = Flux.mse(w*x, w1*x) p = [0.0]
flag = 1 steps = 1:8
decay_steps = [] eta_expected = @. max(o.eta * 0.5 ^ steps, o.clip)
for t = 1:10^5 eta_actual = [Optimise.apply!(o, p, [1.0])[1] for _ in steps]
prev_eta = o.eta @test eta_actual == eta_expected
θ = Params([w1]) end
x = rand(10)
θ̄ = gradient(() -> loss(x), θ) w = randn(10, 10)
prev_grad = collect(θ̄[w1]) o = ExpDecay(0.1, 0.1, 1000, 1e-4)
delta = Optimise.apply!(o, w1, θ̄[w1]) w1 = randn(10,10)
w1 .-= delta loss(x) = Flux.mse(w*x, w1*x)
new_eta = o.eta flag = 1
if new_eta != prev_eta decay_steps = []
push!(decay_steps, t) for t = 1:10^5
end prev_eta = o.eta
array = fill(o.eta, size(prev_grad)) θ = Params([w1])
if array .* prev_grad != delta x = rand(10)
flag = 0 θ̄ = gradient(() -> loss(x), θ)
end prev_grad = collect(θ̄[w1])
delta = Optimise.apply!(o, w1, θ̄[w1])
w1 .-= delta
new_eta = o.eta
if new_eta != prev_eta
push!(decay_steps, t)
end end
@test flag == 1 array = fill(o.eta, size(prev_grad))
# Test to check if decay happens at decay steps. Eta reaches clip value eventually. if array .* prev_grad != delta
ground_truth = [] flag = 0
for i in 1:11
push!(ground_truth, 1000*i) # Expected decay steps for this example.
end end
@test decay_steps == ground_truth end
@test o.eta == o.clip @test flag == 1
# Test to check if decay happens at decay steps. Eta reaches clip value (1e-4) after 4000 steps (decay by 0.1 every 1000 steps starting at 0.1).
ground_truth = []
for i in 1:4
push!(ground_truth, 1000*i) # Expected decay steps for this example.
end
@test decay_steps == ground_truth
@test o.eta == o.clip
end end
@testset "Clipping" begin @testset "Clipping" begin