commit
221670a2b1
@ -307,7 +307,7 @@ function apply!(o::ExpDecay, x, Δ)
|
|||||||
η = max(η * decay^(s / n), o.clip)
|
η = max(η * decay^(s / n), o.clip)
|
||||||
o.eta = η
|
o.eta = η
|
||||||
end
|
end
|
||||||
@. Δ *= decay
|
@. Δ *= η
|
||||||
end
|
end
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
@ -53,3 +53,36 @@ end
|
|||||||
cbs()
|
cbs()
|
||||||
@test x == 1
|
@test x == 1
|
||||||
end
|
end
|
||||||
|
|
||||||
|
@testset "ExpDecay" begin
|
||||||
|
w = randn(10, 10)
|
||||||
|
o = ExpDecay(0.1, 0.1, 1000, 1e-4)
|
||||||
|
w1 = param(randn(10,10))
|
||||||
|
loss(x) = Flux.mse(w*x, w1*x)
|
||||||
|
flag = 1
|
||||||
|
decay_steps = []
|
||||||
|
for t = 1:10^5
|
||||||
|
l = loss(rand(10))
|
||||||
|
back!(l)
|
||||||
|
prev_eta = o.eta
|
||||||
|
prev_grad = collect(w1.grad)
|
||||||
|
delta = Optimise.apply!(o, w1.data, w1.grad)
|
||||||
|
w1.data .-= delta
|
||||||
|
new_eta = o.eta
|
||||||
|
if new_eta != prev_eta
|
||||||
|
push!(decay_steps, t)
|
||||||
|
end
|
||||||
|
array = fill(o.eta, size(prev_grad))
|
||||||
|
if array .* prev_grad != delta
|
||||||
|
flag = 0
|
||||||
|
end
|
||||||
|
end
|
||||||
|
@test flag == 1
|
||||||
|
# Test to check if decay happens at decay steps. Eta reaches clip value eventually.
|
||||||
|
ground_truth = []
|
||||||
|
for i in 1:11
|
||||||
|
push!(ground_truth, 1000*i) # Expected decay steps for this example.
|
||||||
|
end
|
||||||
|
@test decay_steps == ground_truth
|
||||||
|
@test o.eta == o.clip
|
||||||
|
end
|
||||||
|
Loading…
Reference in New Issue
Block a user