test restructure
This commit is contained in:
parent
06fd5adddc
commit
498fa9dd4c
|
@ -1,48 +0,0 @@
|
|||
xs = rand(20)
|
||||
d = Affine(20, 10)
|
||||
|
||||
# MXNet
|
||||
|
||||
@mxonly let dm = mxnet(d, (20, 1))
|
||||
@test d(xs) ≈ dm(xs)
|
||||
end
|
||||
|
||||
@mxonly let
|
||||
# TODO: test run
|
||||
using MXNet
|
||||
f = mx.FeedForward(Chain(d, softmax))
|
||||
@test mx.infer_shape(f.arch, data = (20, 1))[2] == [(10, 1)]
|
||||
|
||||
m = Chain(Input(28,28), Conv2D((5,5), out = 3), MaxPool((2,2)),
|
||||
flatten, Affine(1587, 10), softmax)
|
||||
f = mx.FeedForward(m)
|
||||
@test mx.infer_shape(f.arch, data = (20, 20, 5, 1))[2] == [(10, 1)]
|
||||
end
|
||||
|
||||
@mxonly let
|
||||
model = TLP(Affine(10, 20), Affine(21, 15))
|
||||
info("The following warning is normal")
|
||||
e = try mxnet(model, (10, 1))
|
||||
catch e e end
|
||||
|
||||
@test isa(e, DataFlow.Interpreter.Exception)
|
||||
@test e.trace[1].func == Symbol("Flux.Affine")
|
||||
@test e.trace[2].func == :TLP
|
||||
end
|
||||
|
||||
# TensorFlow
|
||||
|
||||
@tfonly let dt = tf(d)
|
||||
@test d(xs) ≈ dt(xs)
|
||||
end
|
||||
|
||||
@tfonly let
|
||||
using TensorFlow
|
||||
|
||||
sess = TensorFlow.Session()
|
||||
X = placeholder(Float32)
|
||||
Y = Tensor(d, X)
|
||||
run(sess, initialize_all_variables())
|
||||
|
||||
@test run(sess, Y, Dict(X=>xs')) ≈ d(xs)'
|
||||
end
|
|
@ -0,0 +1,34 @@
|
|||
using MXNet
|
||||
Flux.loadmx()
|
||||
|
||||
@testset "MXNet" begin
|
||||
|
||||
xs = rand(20)
|
||||
d = Affine(20, 10)
|
||||
|
||||
dm = mxnet(d, (20, 1))
|
||||
@test d(xs) ≈ dm(xs)
|
||||
|
||||
@testset "FeedForward interface" begin
|
||||
# TODO: test run
|
||||
f = mx.FeedForward(Chain(d, softmax))
|
||||
@test mx.infer_shape(f.arch, data = (20, 1))[2] == [(10, 1)]
|
||||
|
||||
m = Chain(Input(28,28), Conv2D((5,5), out = 3), MaxPool((2,2)),
|
||||
flatten, Affine(1587, 10), softmax)
|
||||
f = mx.FeedForward(m)
|
||||
@test mx.infer_shape(f.arch, data = (20, 20, 5, 1))[2] == [(10, 1)]
|
||||
end
|
||||
|
||||
@testset "Stack Traces" begin
|
||||
model = TLP(Affine(10, 20), Affine(21, 15))
|
||||
info("The following warning is normal")
|
||||
e = try mxnet(model, (10, 1))
|
||||
catch e e end
|
||||
|
||||
@test isa(e, DataFlow.Interpreter.Exception)
|
||||
@test e.trace[1].func == Symbol("Flux.Affine")
|
||||
@test e.trace[2].func == :TLP
|
||||
end
|
||||
|
||||
end
|
|
@ -0,0 +1,21 @@
|
|||
using TensorFlow
|
||||
Flux.loadtf()
|
||||
|
||||
@testset "TensorFlow" begin
|
||||
|
||||
xs = rand(20)
|
||||
d = Affine(20, 10)
|
||||
|
||||
dt = tf(d)
|
||||
@test d(xs) ≈ dt(xs)
|
||||
|
||||
@testset "Tensor interface" begin
|
||||
sess = TensorFlow.Session()
|
||||
X = placeholder(Float32)
|
||||
Y = Tensor(d, X)
|
||||
run(sess, initialize_all_variables())
|
||||
|
||||
@test run(sess, Y, Dict(X=>xs')) ≈ d(xs)'
|
||||
end
|
||||
|
||||
end
|
|
@ -16,4 +16,5 @@ end
|
|||
include("batching.jl")
|
||||
include("basic.jl")
|
||||
include("recurrent.jl")
|
||||
include("backend.jl")
|
||||
@tfonly include("backend/tensorflow.jl")
|
||||
@mxonly include("backend/mxnet.jl")
|
||||
|
|
Loading…
Reference in New Issue