Test for decay_step

This commit is contained in:
thebhatman 2019-05-01 23:10:00 +05:30
parent e459551336
commit 5e06d8bb76
1 changed files with 11 additions and 0 deletions

View File

@ -60,16 +60,27 @@ end
w1 = param(randn(10,10))
loss(x) = Flux.mse(w*x, w1*x)
flag = 1
step_flag = 1
decay_count = 0
for t = 1:10^5
l = loss(rand(10))
back!(l)
prev_eta = o.eta
prev_grad = collect(w1.grad)
delta = Optimise.apply!(o, w1.data, w1.grad)
w1.data .-= delta
new_eta = o.eta
if new_eta != prev_eta
decay_count += 1
if div(t, decay_count) != o.step
step_flag = 0
end
end
array = fill(o.eta, size(prev_grad))
if array .* prev_grad != delta
flag = 0
end
end
@test flag == 1
@test step_flag == 1
end