# Flux.jl/test/optimise.jl
# (18 lines, 395 B, Julia; captured from a repository blame view)
# 2017-10-12 08:31:38 +00:00
using Flux.Optimise
using Flux.Tracker
# Check that every optimiser can recover a fixed random target matrix `w`.
# The learned parameter starts at a random value and is trained to minimise the
# mean-squared error between `w*x` and its own product with random inputs `x`.
@testset "Optimise" begin
    # Fixed target weights the optimisers try to recover (not tracked, not trained).
    w = randn(10, 10)
    # `@testset for` reports each optimiser as its own test set.
    # ADAGrad is wrapped to pass an extra positional argument of 0.1
    # (presumably its learning rate; confirm against the ADAGrad constructor).
    @testset for Opt in [SGD, Nesterov, Momentum, ADAM, RMSProp,
                         ps -> ADAGrad(ps, 0.1), ADADelta, AMSGrad]
        # The trainable parameter MUST be a variable distinct from `w`.
        # Reusing `w` would make the loss compare `w*x` with itself, so the loss
        # would be identically zero and the final check would pass vacuously.
        w_hat = param(randn(10, 10))
        loss(x) = Flux.mse(w*x, w_hat*x)
        opt = Opt([w_hat])
        for _ in 1:10^5
            l = loss(rand(10))
            back!(l)   # backpropagate the loss through the tracked `w_hat`
            opt()      # let the optimiser update its parameter list ([w_hat])
        end
        @test Flux.mse(w, w_hat) < 0.01
    end
end