Comparing decay steps with expected true decay steps
This commit is contained in:
parent
5e06d8bb76
commit
5ffc3b2d40
|
@ -60,8 +60,7 @@ end
|
|||
w1 = param(randn(10,10))
|
||||
loss(x) = Flux.mse(w*x, w1*x)
|
||||
flag = 1
|
||||
step_flag = 1
|
||||
decay_count = 0
|
||||
decay_steps = []
|
||||
for t = 1:10^5
|
||||
l = loss(rand(10))
|
||||
back!(l)
|
||||
|
@ -71,10 +70,7 @@ end
|
|||
w1.data .-= delta
|
||||
new_eta = o.eta
|
||||
if new_eta != prev_eta
|
||||
decay_count += 1
|
||||
if div(t, decay_count) != o.step
|
||||
step_flag = 0
|
||||
end
|
||||
push!(decay_steps, t)
|
||||
end
|
||||
array = fill(o.eta, size(prev_grad))
|
||||
if array .* prev_grad != delta
|
||||
|
@ -82,5 +78,11 @@ end
|
|||
end
|
||||
end
|
||||
@test flag == 1
|
||||
@test step_flag == 1
|
||||
# Test to check if decay happens at decay steps. Eta reaches clip value eventually.
|
||||
ground_truth = []
|
||||
for i in 1:11
|
||||
push!(ground_truth, 1000*i) # Expected decay steps for this example.
|
||||
end
|
||||
@test decay_steps == ground_truth
|
||||
@test o.eta == o.clip
|
||||
end
|
||||
|
|
Loading…
Reference in New Issue