2017-10-12 08:31:38 +00:00
|
|
|
|
using Flux.Optimise
|
2018-11-08 13:14:57 +00:00
|
|
|
|
using Flux.Optimise: runall
|
2017-10-12 08:31:38 +00:00
|
|
|
|
using Flux.Tracker
|
2018-08-11 12:54:59 +00:00
|
|
|
|
using Test
|
2017-10-12 08:31:38 +00:00
|
|
|
|
# For every optimiser in the list, repeatedly take a gradient step through
# `Optimise.update!` on a matrix wrapped with `param`, then check that it ends up
# close (mean squared error below 0.01) to the fixed random target `w`.
@testset "Optimise" begin
  w = randn(10, 10)
  @testset for opt in [ADAMW(), ADAGrad(0.1), AdaMax(), ADADelta(0.9), AMSGrad(),
                       NADAM(), Descent(0.1), ADAM(), Nesterov(), RMSProp(),
                       Momentum()]
    w′ = param(randn(10, 10))
    loss(x) = Flux.mse(w*x, w′*x)
    # The parameter collection is the same on every step, so build it once
    # rather than on each of the 10^5 iterations.
    θ = Params([w′])
    for t = 1:10^5
      # Gradient of the loss on a fresh random input, taken w.r.t. θ.
      θ̄ = gradient(() -> loss(rand(10)), θ)
      Optimise.update!(opt, θ, θ̄)
    end
    @test Flux.mse(w, w′) < 0.01
  end
end
|
2017-12-13 18:24:56 +00:00
|
|
|
|
|
2018-10-11 04:37:16 +00:00
|
|
|
|
# Composes each decay-style optimiser with ADAM(0.001) via `Optimiser`, applies
# the resulting step by hand for 10^5 iterations, and checks that the trained
# matrix ends up close to the fixed random target.
@testset "Optimiser" begin
  w = randn(10, 10)
  @testset for Opt in [InvDecay, WeightDecay, ExpDecay]
    w′ = param(randn(10, 10))
    function loss(x)
      return Flux.mse(w*x, w′*x)
    end
    opt = Optimiser(Opt(), ADAM(0.001))
    for _ in 1:10^5
      sample_loss = loss(rand(10))
      back!(sample_loss)
      # NOTE(review): nothing visible here resets `w′.grad` between
      # iterations — confirm that is intended.
      step = Optimise.apply!(opt, w′.data, w′.grad)
      w′.data .-= step
    end
    @test Flux.mse(w, w′) < 0.01
  end
end
|
|
|
|
|
|
2017-12-13 18:24:56 +00:00
|
|
|
|
# Exercises `Flux.train!` with a throttled callback that calls `Flux.stop()`,
# then checks that `runall` invokes every callback it is given.
@testset "Training Loop" begin
  i = 0
  l = param(1)

  # Each loss evaluation sleeps for 0.1s, bumps the counter and returns `l`.
  objective = () -> (sleep(0.1); i += 1; l)
  # Callback wrapped in `Flux.throttle(…, 1)`; it requests a stop once more
  # than three evaluations have happened.
  stopper = Flux.throttle(() -> (i > 3 && Flux.stop()), 1)

  Flux.train!(objective, (), Iterators.repeated((), 100), Descent(); cb = stopper)

  # The data iterator offers 100 items; the counter must show training ended early.
  @test 3 < i < 50

  # Test multiple callbacks
  x = 0
  noop = () -> ()
  setter = () -> (x = 1)
  combined = runall([noop, setter])
  combined()
  @test x == 1
end
|
2019-04-11 16:23:36 +00:00
|
|
|
|
|
|
|
|
|
# Drives an `ExpDecay` optimiser by hand for 10^5 steps and checks three things:
#   1. every returned step equals the optimiser's current `eta` times the
#      gradient observed just before `apply!` was called;
#   2. `eta` changes exactly at the expected iteration numbers;
#   3. `eta` ends up equal to the optimiser's `clip` field.
@testset "ExpDecay" begin
  w = randn(10, 10)
  o = ExpDecay(0.1, 0.1, 1000, 1e-4)
  w1 = param(randn(10,10))
  loss(x) = Flux.mse(w*x, w1*x)
  scaled_correctly = true
  decay_steps = Int[]  # iteration numbers at which `o.eta` was seen to change
  for t = 1:10^5
    l = loss(rand(10))
    back!(l)
    prev_eta = o.eta
    # Snapshot of the gradient taken before `apply!` is called, so the
    # comparison below uses the pre-call values.
    prev_grad = collect(w1.grad)
    delta = Optimise.apply!(o, w1.data, w1.grad)
    w1.data .-= delta
    if o.eta != prev_eta
      push!(decay_steps, t)
    end
    # Scalar broadcast; no need to materialise an array filled with `o.eta`.
    if o.eta .* prev_grad != delta
      scaled_correctly = false
    end
  end
  @test scaled_correctly
  # Test to check if decay happens at decay steps. Eta reaches clip value eventually.
  ground_truth = collect(1000:1000:11000)  # Expected decay steps for this example.
  @test decay_steps == ground_truth
  @test o.eta == o.clip
end
|