From f64dca2df6508df30c03a2e1709dea61e0f11071 Mon Sep 17 00:00:00 2001 From: ylxdzsw Date: Mon, 26 Jun 2017 17:21:17 +0800 Subject: [PATCH] add test for optimizers --- test/optimizer.jl | 38 ++++++++++++++++++++++++++++++++++++++ test/runtests.jl | 2 ++ 2 files changed, 40 insertions(+) create mode 100644 test/optimizer.jl diff --git a/test/optimizer.jl b/test/optimizer.jl new file mode 100644 index 00000000..57f1d011 --- /dev/null +++ b/test/optimizer.jl @@ -0,0 +1,38 @@ +@testset "training julia models" begin + + @testset "linear regression" begin + srand(0) + + model = Affine(10, 1) + + truth = Float32[0, 4, 2, 2, -3, 6, -1, 3, 2, -5]' + + data = map(1:256) do i + x = rand(Float32, 10) + x, truth * x + 3rand(Float32) + end + + Flux.train!(model, data, epoch=5) + + @test cor(reshape.((model.W.x, truth), 10)...) > .99 + end + + @testset "logistic regression" begin + srand(0) + + model = Chain(Affine(10, 1), σ) + + truth = Float32[0, 4, 2, 2, -3, 6, -1, 3, 2, -5]' + + data = map(1:256) do i + x = rand(Float32, 10) + x, truth * x + 2rand(Float32) > 5f0 + end + + Flux.train!(model, data, epoch=10) + + @test cor(reshape.((model.layers[1].W.x, truth), 10)...) > .99 + end + +end + diff --git a/test/runtests.jl b/test/runtests.jl index 8dd1dd8e..1e4981f7 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -15,5 +15,7 @@ include("backend/common.jl") include("basic.jl") include("recurrent.jl") +include("optimizer.jl") + @tfonly include("backend/tensorflow.jl") @mxonly include("backend/mxnet.jl")