test restructure

This commit is contained in:
Mike J Innes 2017-02-23 22:28:18 +00:00
parent 06fd5adddc
commit 498fa9dd4c
4 changed files with 57 additions and 49 deletions

View File

@ -1,48 +0,0 @@
xs = rand(20)
d = Affine(20, 10)
# MXNet
@mxonly let dm = mxnet(d, (20, 1))
@test d(xs) dm(xs)
end
@mxonly let
# TODO: test run
using MXNet
f = mx.FeedForward(Chain(d, softmax))
@test mx.infer_shape(f.arch, data = (20, 1))[2] == [(10, 1)]
m = Chain(Input(28,28), Conv2D((5,5), out = 3), MaxPool((2,2)),
flatten, Affine(1587, 10), softmax)
f = mx.FeedForward(m)
@test mx.infer_shape(f.arch, data = (20, 20, 5, 1))[2] == [(10, 1)]
end
@mxonly let
model = TLP(Affine(10, 20), Affine(21, 15))
info("The following warning is normal")
e = try mxnet(model, (10, 1))
catch e e end
@test isa(e, DataFlow.Interpreter.Exception)
@test e.trace[1].func == Symbol("Flux.Affine")
@test e.trace[2].func == :TLP
end
# TensorFlow
@tfonly let dt = tf(d)
@test d(xs) dt(xs)
end
@tfonly let
using TensorFlow
sess = TensorFlow.Session()
X = placeholder(Float32)
Y = Tensor(d, X)
run(sess, initialize_all_variables())
@test run(sess, Y, Dict(X=>xs')) d(xs)'
end

34
test/backend/mxnet.jl Normal file
View File

@ -0,0 +1,34 @@
using MXNet
Flux.loadmx()
@testset "MXNet" begin
xs = rand(20)
d = Affine(20, 10)
dm = mxnet(d, (20, 1))
@test d(xs) dm(xs)
@testset "FeedForward interface" begin
# TODO: test run
f = mx.FeedForward(Chain(d, softmax))
@test mx.infer_shape(f.arch, data = (20, 1))[2] == [(10, 1)]
m = Chain(Input(28,28), Conv2D((5,5), out = 3), MaxPool((2,2)),
flatten, Affine(1587, 10), softmax)
f = mx.FeedForward(m)
@test mx.infer_shape(f.arch, data = (20, 20, 5, 1))[2] == [(10, 1)]
end
@testset "Stack Traces" begin
model = TLP(Affine(10, 20), Affine(21, 15))
info("The following warning is normal")
e = try mxnet(model, (10, 1))
catch e e end
@test isa(e, DataFlow.Interpreter.Exception)
@test e.trace[1].func == Symbol("Flux.Affine")
@test e.trace[2].func == :TLP
end
end

View File

@ -0,0 +1,21 @@
using TensorFlow
Flux.loadtf()
@testset "TensorFlow" begin
xs = rand(20)
d = Affine(20, 10)
dt = tf(d)
@test d(xs) dt(xs)
@testset "Tensor interface" begin
sess = TensorFlow.Session()
X = placeholder(Float32)
Y = Tensor(d, X)
run(sess, initialize_all_variables())
@test run(sess, Y, Dict(X=>xs')) d(xs)'
end
end

View File

@ -16,4 +16,5 @@ end
include("batching.jl")
include("basic.jl")
include("recurrent.jl")
include("backend.jl")
@tfonly include("backend/tensorflow.jl")
@mxonly include("backend/mxnet.jl")