factor out common tests

This commit is contained in:
Mike J Innes 2017-05-04 13:52:31 +01:00
parent 70286c0bf3
commit a2db4b5319
6 changed files with 70 additions and 70 deletions

View File

@ -18,6 +18,7 @@ node(x::Tuple) = map(node, x)
node(x::mx.SymbolicNode) = x
graph(::typeof(tuple), args...) = (args...,)
graph(::typeof(identity), x) = x
graph(::typeof(getindex), t::Tuple, n::Integer) = t[n]
graph(::typeof(.+), args...) = mx.broadcast_plus(args...)
graph(::typeof(.*), args...) = mx.broadcast_mul(args...)

57
test/backend/common.jl Normal file
View File

@ -0,0 +1,57 @@
@net type TLP
first
second
function (x)
l1 = σ(first(x))
l2 = softmax(second(l1))
end
end
function test_tupleio(bk)
@testset "Tuple I/O" begin
val = [1,2,3]
tup = ([1,2,3],[4,5,6])
@test bk(@net x -> (identity(x),))(val) == (val,)
@test bk(@net x -> x[1].*x[2])(tup) == [4,10,18]
end
end
function test_recurrence(bk)
@testset "Recurrence" begin
seq = batchone(Seq(rand(10) for i = 1:3))
r = unroll(Recurrent(10, 5), 3)
rm = bk(r)
@test r(seq) rm(seq)
end
end
function test_back(bk)
@testset "Backward Pass" begin
xs, ys = rand(1, 20), rand(1, 20)
d = Affine(20, 10)
dm = bk(d)
d = deepcopy(d)
@test dm(xs) d(xs)
@test dm(xs) d(xs)
Δ = back!(dm, randn(1, 10), xs)
@test length(Δ[1]) == 20
update!(dm, 0.1)
@test dm(xs) d(xs)
@test dm(xs) d(xs)
end
end
function test_stacktrace(bk)
@testset "Stack Traces" begin
model = TLP(Affine(10, 20), Affine(21, 15))
dm = bk(model)
e = try dm(rand(1, 10))
catch e e end
@test isa(e, DataFlow.Interpreter.Exception)
@test e.trace[1].func == Symbol("Flux.Affine")
@test e.trace[2].func == :TLP
end
end

View File

@ -9,30 +9,10 @@ d = Affine(20, 10)
dm = mxnet(d)
@test d(xs) dm(xs)
@testset "Tuple I/O" begin
@test mxnet(@net x -> (x,))([1,2,3]) == ([1,2,3],)
@test mxnet(@net x -> x[1].*x[2])(([1,2,3],[4,5,6])) == [4,10,18]
end
@testset "Recurrence" begin
seq = batchone(Seq(rand(10) for i = 1:3))
r = unroll(Recurrent(10, 5), 3)
rm = mxnet(r)
@test r(seq) rm(seq)
end
@testset "Backward Pass" begin
d = deepcopy(d)
@test dm(xs) d(xs)
@test dm(xs) d(xs)
Δ = back!(dm, randn(1, 10), xs)
@test length(Δ[1]) == 20
update!(dm, 0.1)
@test dm(xs) d(xs)
@test dm(xs) d(xs)
end
test_tupleio(mxnet)
test_recurrence(mxnet)
test_stacktrace(mxnet)
test_back(mxnet)
@testset "Native interface" begin
f = mx.FeedForward(Chain(d, softmax))
@ -45,16 +25,4 @@ end
@test mx.infer_shape(f.arch, data = (20, 20, 5, 1))[2] == [(10, 1)]
end
@testset "Stack Traces" begin
model = TLP(Affine(10, 20), Affine(21, 15))
info("The following warning is normal")
dm = mxnet(model)
e = try dm(rand(1, 10))
catch e e end
@test isa(e, DataFlow.Interpreter.Exception)
@test e.trace[1].func == Symbol("Flux.Affine")
@test e.trace[2].func == :TLP
end
end

View File

@ -9,17 +9,9 @@ d = Affine(20, 10)
dt = tf(d)
@test d(xs) dt(xs)
@testset "Tuple I/O" begin
@test tf(@net x -> (identity(x),))([1,2,3]) == ([1,2,3],)
@test tf(@net x -> x[1].*x[2])(([1,2,3],[4,5,6])) == [4,10,18]
end
@testset "Recurrence" begin
seq = batchone(Seq(rand(10) for i = 1:3))
r = unroll(Recurrent(10, 5), 3)
rm = tf(r)
@test r(seq) rm(seq)
end
test_tupleio(tf)
test_recurrence(tf)
test_stacktrace(tf)
@testset "Tensor interface" begin
sess = TensorFlow.Session()
@ -30,15 +22,4 @@ end
@test run(sess, Y, Dict(X=>xs)) d(xs)
end
@testset "Stack Traces" begin
model = TLP(Affine(10, 20), Affine(21, 15))
dm = tf(model)
e = try dm(rand(1, 10))
catch e e end
@test isa(e, DataFlow.Interpreter.Exception)
@test e.trace[1].func == Symbol("Flux.Affine")
@test e.trace[2].func == :TLP
end
end

View File

@ -1,3 +1,6 @@
syntax(v::Vertex) = prettify(DataFlow.syntax(v))
syntax(x) = syntax(graph(x))
@testset "Basics" begin
xs = randn(1, 10)

View File

@ -2,9 +2,6 @@ using Flux, DataFlow, MacroTools, Base.Test
using Flux: graph, Param, unsqueeze
using DataFlow: Line, Frame
syntax(v::Vertex) = prettify(DataFlow.syntax(v))
syntax(x) = syntax(graph(x))
macro mxonly(ex)
:(Base.find_in_path("MXNet") nothing && $(esc(ex)))
end
@ -13,16 +10,9 @@ macro tfonly(ex)
:(Base.find_in_path("TensorFlow") nothing && $(esc(ex)))
end
@net type TLP
first
second
function (x)
l1 = σ(first(x))
l2 = softmax(second(l1))
end
end
include("batching.jl")
include("backend/common.jl")
include("basic.jl")
include("recurrent.jl")
@tfonly include("backend/tensorflow.jl")