Added test for ExpDecay

This commit is contained in:
thebhatman 2019-04-11 21:53:36 +05:30
parent 31a50ab16a
commit fb3001b8b2
1 changed files with 19 additions and 0 deletions

View File

@ -53,3 +53,22 @@ end
cbs()
@test x == 1
end
@testset "ExpDecay" begin
w = randn(10, 10)
o = ExpDecay(0.1, decay = 0.1, decay_step = 1000, clip = 1e-4)
w1 = param(randn(10,10))
loss(x) = Flux.mse(w*x, w1*x)
flag = 1
for t = 1:10^5
l = loss(rand(10))
back!(l)
prev_grad = collect(w1.grad)
delta = Optimise.apply!(o, w1.data, w1.grad)
array = fill(o.eta, size(prev_grad))
if array .* prev_grad != delta
flag = 0
end
end
@test flag == 1
end