fixed ExpDecay test
This commit is contained in:
parent
ea508a79b0
commit
32ce2d78b8
|
@ -279,7 +279,7 @@ mutable struct ExpDecay
|
|||
current::IdDict
|
||||
end
|
||||
|
||||
ExpDecay(opt, decay = 0.1, decay_step = 1000, clip = 1e-4) = ExpDecay(opt, decay, decay_step, clip, IdDict())
|
||||
ExpDecay(opt = Descent(), decay = 0.1, decay_step = 1000, clip = 1e-4) = ExpDecay(opt, decay, decay_step, clip, IdDict())
|
||||
|
||||
function update!(o::ExpDecay, x, Δ)
|
||||
s, decay = o.step, o.decay
|
||||
|
|
|
@ -30,11 +30,12 @@ end
|
|||
loss(x) = Flux.mse(w*x, w′*x)
|
||||
opt = Optimiser(Opt(), ADAM(0.001))
|
||||
if Opt isa ExpDecay
|
||||
opt = ExpDecay(ADAM(), 0.9)
|
||||
opt = ExpDecay(ADAM(), 0.9, 1000)
|
||||
end
|
||||
for t = 1:10^5
|
||||
l = loss(rand(10))
|
||||
back!(l)
|
||||
delta = Optimise.update!(opt, w′)
|
||||
delta = Optimise.update!(opt, w′.data, w′.grad)
|
||||
w′.data .-= delta
|
||||
end
|
||||
@test Flux.mse(w, w′) < 0.01
|
||||
|
|
Loading…
Reference in New Issue