Decay checking test added back
This commit is contained in:
parent
4e9f3deb7f
commit
8292cfd81f
@ -69,17 +69,19 @@ end
|
|||||||
θ = Params([w1])
|
θ = Params([w1])
|
||||||
x = rand(10)
|
x = rand(10)
|
||||||
θ̄ = gradient(() -> loss(x), θ)
|
θ̄ = gradient(() -> loss(x), θ)
|
||||||
Optimise.update!(o, θ, θ̄)
|
prev_grad = collect(θ̄[w1])
|
||||||
|
delta = Optimise.apply!(o, w1, θ̄[w1])
|
||||||
|
w1 .-= delta
|
||||||
new_eta = o.eta
|
new_eta = o.eta
|
||||||
if new_eta != prev_eta
|
if new_eta != prev_eta
|
||||||
push!(decay_steps, t)
|
push!(decay_steps, t)
|
||||||
end
|
end
|
||||||
# array = fill(o.eta, size(prev_grad))
|
array = fill(o.eta, size(prev_grad))
|
||||||
# if array .* prev_grad != delta
|
if array .* prev_grad != delta
|
||||||
# flag = 0
|
flag = 0
|
||||||
# end
|
end
|
||||||
end
|
end
|
||||||
#@test flag == 1
|
@test flag == 1
|
||||||
# Test to check if decay happens at decay steps. Eta reaches clip value eventually.
|
# Test to check if decay happens at decay steps. Eta reaches clip value eventually.
|
||||||
ground_truth = []
|
ground_truth = []
|
||||||
for i in 1:11
|
for i in 1:11
|
||||||
|
Loading…
Reference in New Issue
Block a user