regression testing
This commit is contained in:
parent
f82dbf4798
commit
69cc5642b4
@ -2,18 +2,16 @@ using Flux.Optimise
|
|||||||
using Flux.Tracker
|
using Flux.Tracker
|
||||||
|
|
||||||
@testset "Optimise" begin
|
@testset "Optimise" begin
|
||||||
loss(x) = sum(x.^2)
|
w = randn(10, 10)
|
||||||
η = 0.1
|
for Opt in [SGD, Nesterov, Momentum, ADAM, RMSProp, ps -> ADAGrad(ps, 0.1), ADADelta]
|
||||||
# RMSProp gets stuck
|
w′ = param(randn(10, 10))
|
||||||
for OPT in [SGD, Nesterov, Momentum, ADAM, ADAGrad, ADADelta]
|
loss(x) = Flux.mse(w*x, w′*x)
|
||||||
x = param(randn(10))
|
opt = Opt([w′])
|
||||||
opt = OPT == ADADelta ? OPT([x]) : OPT([x], η)
|
for t=1:10^5
|
||||||
for t=1:10000
|
l = loss(rand(10))
|
||||||
l = loss(x)
|
back!(l)
|
||||||
back!(l)
|
opt()
|
||||||
opt()
|
|
||||||
l.data[] < 1e-10 && break
|
|
||||||
end
|
|
||||||
@test loss(x) ≈ 0. atol=1e-7
|
|
||||||
end
|
end
|
||||||
|
@test Flux.mse(w, w′) < 0.01
|
||||||
|
end
|
||||||
end
|
end
|
||||||
|
Loading…
Reference in New Issue
Block a user