Tests for Optimisers supporting Zygote

This commit is contained in:
thebhatman 2019-06-06 04:09:17 +05:30
parent fecb6bd16f
commit 0ddb5f0265

View File

@@ -2,87 +2,87 @@ using Flux.Optimise
using Flux.Optimise: runall using Flux.Optimise: runall
using Zygote: Params, gradient using Zygote: Params, gradient
using Test using Test
# @testset "Optimise" begin @testset "Optimise" begin
# w = randn(10, 10) w = randn(10, 10)
# @testset for opt in [ADAMW(), ADAGrad(0.1), AdaMax(), ADADelta(0.9), AMSGrad(), @testset for opt in [ADAMW(), ADAGrad(0.1), AdaMax(), ADADelta(0.9), AMSGrad(),
# NADAM(), Descent(0.1), ADAM(), Nesterov(), RMSProp(), NADAM(), Descent(0.1), ADAM(), Nesterov(), RMSProp(),
# Momentum()] Momentum()]
# w = randn(10, 10) w = randn(10, 10)
# loss(x) = Flux.mse(w*x, w*x) loss(x) = Flux.mse(w*x, w*x)
# for t = 1: 10^5 for t = 1: 10^5
# θ = Params([w]) θ = Params([w])
# θ̄ = gradient(() -> loss(rand(10)), θ) x = rand(10)
# Optimise.update!(opt, θ, θ̄) θ̄ = gradient(() -> loss(x), θ)
# end Optimise.update!(opt, θ, θ̄)
# @test Flux.mse(w, w) < 0.01 end
# end @test loss(rand(10, 10)) < 0.01
# end end
end
# @testset "Optimiser" begin @testset "Optimiser" begin
# w = randn(10, 10) w = randn(10, 10)
# @testset for Opt in [InvDecay, WeightDecay, ExpDecay] @testset for Opt in [InvDecay, WeightDecay, ExpDecay]
# w = param(randn(10, 10)) w = randn(10, 10)
# loss(x) = Flux.mse(w*x, w*x) loss(x) = Flux.mse(w*x, w*x)
# opt = Optimiser(Opt(), ADAM(0.001)) opt = Optimiser(Opt(), ADAM(0.001))
# for t = 1:10^5 for t = 1:10^5
# l = loss(rand(10)) θ = Params([w])
# back!(l) x = rand(10)
# delta = Optimise.apply!(opt, w.data, w.grad) θ̄ = gradient(() -> loss(x), θ)
# w.data .-= delta Optimise.update!(opt, θ, θ̄)
# end end
# @test Flux.mse(w, w) < 0.01 @test loss(rand(10, 10)) < 0.01
# end end
# end end
# @testset "Training Loop" begin @testset "Training Loop" begin
# i = 0 i = 0
# l = 1 l = 1
#
# Flux.train!(() -> (sleep(0.1); i += 1; l), Flux.train!(() -> (sleep(0.1); i += 1; l),
# (), (),
# Iterators.repeated((), 100), Iterators.repeated((), 100),
# Descent(), Descent(),
# cb = Flux.throttle(() -> (i > 3 && Flux.stop()), 1)) cb = Flux.throttle(() -> (i > 3 && Flux.stop()), 1))
#
# @test 3 < i < 50 @test 3 < i < 50
#
# # Test multiple callbacks # Test multiple callbacks
# x = 0 x = 0
# fs = [() -> (), () -> x = 1] fs = [() -> (), () -> x = 1]
# cbs = runall(fs) cbs = runall(fs)
# cbs() cbs()
# @test x == 1 @test x == 1
# end end
#
# @testset "ExpDecay" begin @testset "ExpDecay" begin
# w = randn(10, 10) w = randn(10, 10)
# o = ExpDecay(0.1, 0.1, 1000, 1e-4) o = ExpDecay(0.1, 0.1, 1000, 1e-4)
# w1 = param(randn(10,10)) w1 = randn(10,10)
# loss(x) = Flux.mse(w*x, w1*x) loss(x) = Flux.mse(w*x, w1*x)
# flag = 1 flag = 1
# decay_steps = [] decay_steps = []
# for t = 1:10^5 for t = 1:10^5
# l = loss(rand(10)) prev_eta = o.eta
# back!(l) θ = Params([w1])
# prev_eta = o.eta x = rand(10)
# prev_grad = collect(w1.grad) θ̄ = gradient(() -> loss(x), θ)
# delta = Optimise.apply!(o, w1.data, w1.grad) Optimise.update!(o, θ, θ̄)
# w1.data .-= delta new_eta = o.eta
# new_eta = o.eta if new_eta != prev_eta
# if new_eta != prev_eta push!(decay_steps, t)
# push!(decay_steps, t) end
# end # array = fill(o.eta, size(prev_grad))
# array = fill(o.eta, size(prev_grad)) # if array .* prev_grad != delta
# if array .* prev_grad != delta # flag = 0
# flag = 0 # end
# end end
# end #@test flag == 1
# @test flag == 1 # Test to check if decay happens at decay steps. Eta reaches clip value eventually.
# # Test to check if decay happens at decay steps. Eta reaches clip value eventually. ground_truth = []
# ground_truth = [] for i in 1:11
# for i in 1:11 push!(ground_truth, 1000*i) # Expected decay steps for this example.
# push!(ground_truth, 1000*i) # Expected decay steps for this example. end
# end @test decay_steps == ground_truth
# @test decay_steps == ground_truth @test o.eta == o.clip
# @test o.eta == o.clip end
# end