Merge pull request #733 from thebhatman/expdecay-fix

Fixed ExpDecay
This commit is contained in:
Dhairya Gandhi 2019-05-01 18:58:37 +05:30 committed by GitHub
commit 221670a2b1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 34 additions and 1 deletions

View File

@ -307,7 +307,7 @@ function apply!(o::ExpDecay, x, Δ)
η = max(η * decay^(s / n), o.clip)
o.eta = η
end
@. Δ *= decay
@. Δ *= η
end
"""

View File

@ -53,3 +53,36 @@ end
cbs()
@test x == 1
end
@testset "ExpDecay" begin
w = randn(10, 10)
o = ExpDecay(0.1, 0.1, 1000, 1e-4)
w1 = param(randn(10,10))
loss(x) = Flux.mse(w*x, w1*x)
flag = 1
decay_steps = []
for t = 1:10^5
l = loss(rand(10))
back!(l)
prev_eta = o.eta
prev_grad = collect(w1.grad)
delta = Optimise.apply!(o, w1.data, w1.grad)
w1.data .-= delta
new_eta = o.eta
if new_eta != prev_eta
push!(decay_steps, t)
end
array = fill(o.eta, size(prev_grad))
if array .* prev_grad != delta
flag = 0
end
end
@test flag == 1
# Test to check if decay happens at decay steps. Eta reaches clip value eventually.
ground_truth = []
for i in 1:11
push!(ground_truth, 1000*i) # Expected decay steps for this example.
end
@test decay_steps == ground_truth
@test o.eta == o.clip
end