weights updated in tests
This commit is contained in:
parent
fb3001b8b2
commit
e459551336
@ -56,7 +56,7 @@ end
|
|||||||
|
|
||||||
@testset "ExpDecay" begin
|
@testset "ExpDecay" begin
|
||||||
w = randn(10, 10)
|
w = randn(10, 10)
|
||||||
o = ExpDecay(0.1, decay = 0.1, decay_step = 1000, clip = 1e-4)
|
o = ExpDecay(0.1, 0.1, 1000, 1e-4)
|
||||||
w1 = param(randn(10,10))
|
w1 = param(randn(10,10))
|
||||||
loss(x) = Flux.mse(w*x, w1*x)
|
loss(x) = Flux.mse(w*x, w1*x)
|
||||||
flag = 1
|
flag = 1
|
||||||
@ -65,6 +65,7 @@ end
|
|||||||
back!(l)
|
back!(l)
|
||||||
prev_grad = collect(w1.grad)
|
prev_grad = collect(w1.grad)
|
||||||
delta = Optimise.apply!(o, w1.data, w1.grad)
|
delta = Optimise.apply!(o, w1.data, w1.grad)
|
||||||
|
w1.data .-= delta
|
||||||
array = fill(o.eta, size(prev_grad))
|
array = fill(o.eta, size(prev_grad))
|
||||||
if array .* prev_grad != delta
|
if array .* prev_grad != delta
|
||||||
flag = 0
|
flag = 0
|
||||||
|
Loading…
Reference in New Issue
Block a user