# Flux.jl/test/optimise.jl

using Flux.Optimise
using Flux.Tracker
using Test
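# These tests target the Tracker-based AD: `param` wraps an array in a tracked
# container, `back!` accumulates gradients into `.grad`, and raw values live in `.data`.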
@testset "Optimise" begin
  w = randn(10, 10)
  @testset for Opt in [ADAGrad, AdaMax, ADADelta, AMSGrad, NADAM, Descent, ADAM, Nesterov, RMSProp, Momentum]
    # w′ is the trainable copy; the loss drives w′*x towards the fixed target w*x.
    w′ = param(randn(10, 10))
    loss(x) = Flux.mse(w*x, w′*x)
    opt = Opt(0.001)
    # Descent and ADAGrad need a larger step size to converge within the iteration budget.
    if opt isa Descent || opt isa ADAGrad
      opt = Opt(0.1)
    end
    # ADADelta's argument is the decay rate ρ rather than a learning rate.
    if opt isa ADADelta
      opt = Opt(0.9)
    end
    for t = 1:10^5
      l = loss(rand(10))
      back!(l)
      delta = Optimise.update!(opt, w′.data, w′.grad)
      w′.data .-= delta
    end
    @test Flux.mse(w, w′) < 0.01
  end
end
@testset "Optimiser" begin
  w = randn(10, 10)
  @testset for Opt in [InvDecay, WeightDecay, ExpDecay]
    w′ = param(randn(10, 10))
    loss(x) = Flux.mse(w*x, w′*x)
    opt = Optimiser(Opt(), ADAM(0.001))
    for t = 1:10^5
      l = loss(rand(10))
      back!(l)
      delta = Optimise.update!(opt, w′.data, w′.grad)
      w′.data .-= delta
    end
    @test Flux.mse(w, w′) < 0.01
  end
end
@testset "Training Loop" begin
  i = 0
  l = param(1)

  Flux.train!(() -> (sleep(0.1); i += 1; l),
              (),
              Iterators.repeated((), 100),
              Descent(),
              cb = Flux.throttle(() -> (i > 3 && Flux.stop()), 1))

  @test 3 < i < 50
end