# Flux.jl/test/optimise.jl

using Flux.Optimise
using Flux.Optimise: runall
using Zygote: Params, gradient
using Test
# @testset "Optimise" begin
#   w = randn(10, 10)
#   @testset for opt in [ADAMW(), ADAGrad(0.1), AdaMax(), ADADelta(0.9), AMSGrad(),
#                        NADAM(), Descent(0.1), ADAM(), Nesterov(), RMSProp(),
#                        Momentum()]
#     w′ = randn(10, 10)              # trainable copy; `w` above is the fixed target
#     loss(x) = Flux.mse(w*x, w′*x)
#     for t = 1:10^5
#       θ = Params([w′])
#       θ̄ = gradient(() -> loss(rand(10)), θ)
#       Optimise.update!(opt, θ, θ̄)
#     end
#     @test Flux.mse(w, w′) < 0.01
#   end
# end
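# Quick live sanity check of the low-level `Optimise.apply!` interface that
# the "Optimiser" testset below relies on. For a plain `Descent` step it is
# expected to hand back a gradient-shaped update; only the shape is asserted
# here, as a hedge, since the exact scaling behaviour is not pinned down by
# the tests above.
let g = ones(10, 10)
  delta = Optimise.apply!(Descent(0.1), zeros(10, 10), g)
  @test size(delta) == size(g)
end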
# @testset "Optimiser" begin
#   w = randn(10, 10)
#   @testset for Opt in [InvDecay, WeightDecay, ExpDecay]
#     w′ = param(randn(10, 10))
#     loss(x) = Flux.mse(w*x, w′*x)
#     opt = Optimiser(Opt(), ADAM(0.001))
#     for t = 1:10^5
#       l = loss(rand(10))
#       back!(l)
#       delta = Optimise.apply!(opt, w′.data, w′.grad)
#       w′.data .-= delta
#     end
#     @test Flux.mse(w, w′) < 0.01
#   end
# end
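# Hedged sketch: the same decay-wrapper check rewritten in the Zygote
# `gradient`/`update!` style of the "Optimise" testset above, since `param`,
# `back!`, `.data` and `.grad` are the old Tracker API. It is left commented
# out like the rest; whether `update!` already accepts an `Optimiser` chain
# at this point is an assumption.
# w = randn(10, 10)
# @testset for Opt in [InvDecay, WeightDecay, ExpDecay]
#   w′ = randn(10, 10)
#   loss(x) = Flux.mse(w*x, w′*x)
#   opt = Optimiser(Opt(), ADAM(0.001))
#   for t = 1:10^5
#     θ = Params([w′])
#     θ̄ = gradient(() -> loss(rand(10)), θ)
#     Optimise.update!(opt, θ, θ̄)
#   end
#   @test Flux.mse(w, w′) < 0.01
# end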
# @testset "Training Loop" begin
#   i = 0
#   l = 1
#
#   Flux.train!(() -> (sleep(0.1); i += 1; l),
#               (),
#               Iterators.repeated((), 100),
#               Descent(),
#               cb = Flux.throttle(() -> (i > 3 && Flux.stop()), 1))
#
#   @test 3 < i < 50
#
#   # Test multiple callbacks
#   x = 0
#   fs = [() -> (), () -> x = 1]
#   cbs = runall(fs)
#   cbs()
#   @test x == 1
# end
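# Live version of the multiple-callbacks check above: it only exercises
# `runall` (imported at the top of this file) and does not depend on the
# training loop, so it can run as-is. The `hits` counter is just for this
# sketch.
let hits = Ref(0)
  cb = runall([() -> hits[] += 1, () -> hits[] += 1])
  cb()
  @test hits[] == 2
end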