Added test for ExpDecay
This commit is contained in:
parent
31a50ab16a
commit
fb3001b8b2
|
@ -53,3 +53,22 @@ end
|
|||
cbs()
|
||||
@test x == 1
|
||||
end
|
||||
|
||||
@testset "ExpDecay" begin
|
||||
w = randn(10, 10)
|
||||
o = ExpDecay(0.1, decay = 0.1, decay_step = 1000, clip = 1e-4)
|
||||
w1 = param(randn(10,10))
|
||||
loss(x) = Flux.mse(w*x, w1*x)
|
||||
flag = 1
|
||||
for t = 1:10^5
|
||||
l = loss(rand(10))
|
||||
back!(l)
|
||||
prev_grad = collect(w1.grad)
|
||||
delta = Optimise.apply!(o, w1.data, w1.grad)
|
||||
array = fill(o.eta, size(prev_grad))
|
||||
if array .* prev_grad != delta
|
||||
flag = 0
|
||||
end
|
||||
end
|
||||
@test flag == 1
|
||||
end
|
||||
|
|
Loading…
Reference in New Issue