2017-10-12 08:31:38 +00:00
|
|
|
|
using Flux.Optimise
|
|
|
|
|
using Flux.Tracker
|
2018-08-11 12:54:59 +00:00
|
|
|
|
using Test
|
2017-10-12 08:31:38 +00:00
|
|
|
|
# Convergence test for the basic optimisers: starting from a random matrix, every
# optimiser in the list must pull the tracked parameter `w′` close to the fixed
# target matrix `w`.
@testset "Optimise" begin
    w = randn(10, 10)
    @testset for Opt in [Descent, ADAM, Nesterov, RMSProp, Momentum]
        w′ = param(randn(10, 10))
        # Compare the outputs of the target and the trained matrix on the same input.
        loss(x) = Flux.mse(w * x, w′ * x)
        # Descent is run with step size 0.1; every other optimiser with 0.001.
        opt = Opt === Descent ? Opt(0.1) : Opt(0.001)
        for _ in 1:10^5
            l = loss(rand(10))
            back!(l)
            delta = Optimise.update!(opt, w′.data, w′.grad)
            w′.data .-= delta
            # Tracker accumulates into `.grad` across `back!` calls, so clear it
            # here; otherwise the next step is computed from stale gradient left
            # over from previous samples. Done after the update because `delta`
            # may alias the gradient buffer.
            w′.grad .= 0
        end
        @test Flux.mse(w, w′) < 0.01
    end
end
|
2017-12-13 18:24:56 +00:00
|
|
|
|
|
|
|
|
|
# Checks that a callback calling `Flux.stop()` ends `Flux.train!` before the data
# iterator (100 items) is exhausted.
@testset "Training Loop" begin
    i = 0
    l = param(1)

    # Loss closure: sleeps 0.1 per call, counts the call, and returns the tracked value.
    function slowloss()
        sleep(0.1)
        i += 1
        return l
    end

    # Requests a stop once more than three loss evaluations have happened.
    # NOTE(review): wrapped in `Flux.throttle(..., 1)`, presumably limiting the
    # callback to one run per second — confirm against the Flux documentation.
    stopper = Flux.throttle(() -> (i > 3 && Flux.stop()), 1)

    Flux.train!(slowloss, Iterators.repeated((), 100), () -> (), cb = stopper)

    # More than 3 evaluations are needed to trigger the stop; fewer than 50 shows
    # the loop ended well before all 100 items were consumed.
    @test 3 < i < 50
end
|