Test for decay_step
This commit is contained in:
parent
e459551336
commit
5e06d8bb76
@ -60,16 +60,27 @@ end
|
|||||||
w1 = param(randn(10,10))
|
w1 = param(randn(10,10))
|
||||||
loss(x) = Flux.mse(w*x, w1*x)
|
loss(x) = Flux.mse(w*x, w1*x)
|
||||||
flag = 1
|
flag = 1
|
||||||
|
step_flag = 1
|
||||||
|
decay_count = 0
|
||||||
for t = 1:10^5
|
for t = 1:10^5
|
||||||
l = loss(rand(10))
|
l = loss(rand(10))
|
||||||
back!(l)
|
back!(l)
|
||||||
|
prev_eta = o.eta
|
||||||
prev_grad = collect(w1.grad)
|
prev_grad = collect(w1.grad)
|
||||||
delta = Optimise.apply!(o, w1.data, w1.grad)
|
delta = Optimise.apply!(o, w1.data, w1.grad)
|
||||||
w1.data .-= delta
|
w1.data .-= delta
|
||||||
|
new_eta = o.eta
|
||||||
|
if new_eta != prev_eta
|
||||||
|
decay_count += 1
|
||||||
|
if div(t, decay_count) != o.step
|
||||||
|
step_flag = 0
|
||||||
|
end
|
||||||
|
end
|
||||||
array = fill(o.eta, size(prev_grad))
|
array = fill(o.eta, size(prev_grad))
|
||||||
if array .* prev_grad != delta
|
if array .* prev_grad != delta
|
||||||
flag = 0
|
flag = 0
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
@test flag == 1
|
@test flag == 1
|
||||||
|
@test step_flag == 1
|
||||||
end
|
end
|
||||||
|
Loading…
Reference in New Issue
Block a user