2017-10-12 08:31:38 +00:00
|
|
|
|
using Flux.Optimise
|
|
|
|
|
using Flux.Tracker
|
|
|
|
|
|
|
|
|
|
@testset "Optimise" begin
  # Ground-truth weights that every optimiser is asked to reproduce.
  target = randn(10, 10)
  # Optimiser constructors under test; each one is called with the vector of
  # trainable parameters. ADAGrad is wrapped so it receives an explicit second
  # argument of 0.1 (presumably its learning rate — confirm against Flux docs).
  constructors = [SGD, Nesterov, Momentum, ADAM, RMSProp,
                  ps -> ADAGrad(ps, 0.1), ADADelta, AMSGrad]
  # NOTE(review): no RNG seed is visible here, so the random starting point and
  # inputs differ between runs; the 0.01 threshold could be flaky.
  for make_opt in constructors
    # Trainable estimate, re-initialised at a fresh random point per optimiser.
    estimate = param(randn(10, 10))
    # MSE between the outputs of the target map and the estimated map on x.
    loss = x -> Flux.mse(target * x, estimate * x)
    opt = make_opt([estimate])
    for iter in 1:10^5
      # Backpropagate the loss on a random input, then invoke the optimiser
      # (presumably applies one update step to `estimate`).
      back!(loss(rand(10)))
      opt()
    end
    # The learned estimate must be within 0.01 MSE of the target weights.
    @test Flux.mse(target, estimate) < 0.01
  end
end
|