factor out common tests
This commit is contained in:
parent
70286c0bf3
commit
a2db4b5319
@ -18,6 +18,7 @@ node(x::Tuple) = map(node, x)
|
||||
node(x::mx.SymbolicNode) = x
|
||||
|
||||
graph(::typeof(tuple), args...) = (args...,)
|
||||
graph(::typeof(identity), x) = x
|
||||
graph(::typeof(getindex), t::Tuple, n::Integer) = t[n]
|
||||
graph(::typeof(.+), args...) = mx.broadcast_plus(args...)
|
||||
graph(::typeof(.*), args...) = mx.broadcast_mul(args...)
|
||||
|
57
test/backend/common.jl
Normal file
57
test/backend/common.jl
Normal file
@ -0,0 +1,57 @@
|
||||
@net type TLP
|
||||
first
|
||||
second
|
||||
function (x)
|
||||
l1 = σ(first(x))
|
||||
l2 = softmax(second(l1))
|
||||
end
|
||||
end
|
||||
|
||||
function test_tupleio(bk)
|
||||
@testset "Tuple I/O" begin
|
||||
val = [1,2,3]
|
||||
tup = ([1,2,3],[4,5,6])
|
||||
@test bk(@net x -> (identity(x),))(val) == (val,)
|
||||
@test bk(@net x -> x[1].*x[2])(tup) == [4,10,18]
|
||||
end
|
||||
end
|
||||
|
||||
function test_recurrence(bk)
|
||||
@testset "Recurrence" begin
|
||||
seq = batchone(Seq(rand(10) for i = 1:3))
|
||||
r = unroll(Recurrent(10, 5), 3)
|
||||
rm = bk(r)
|
||||
@test r(seq) ≈ rm(seq)
|
||||
end
|
||||
end
|
||||
|
||||
function test_back(bk)
|
||||
@testset "Backward Pass" begin
|
||||
xs, ys = rand(1, 20), rand(1, 20)
|
||||
d = Affine(20, 10)
|
||||
dm = bk(d)
|
||||
d′ = deepcopy(d)
|
||||
@test dm(xs) ≈ d(xs)
|
||||
@test dm(xs) ≈ d′(xs)
|
||||
|
||||
Δ = back!(dm, randn(1, 10), xs)
|
||||
@test length(Δ[1]) == 20
|
||||
update!(dm, 0.1)
|
||||
|
||||
@test dm(xs) ≈ d(xs)
|
||||
@test dm(xs) ≉ d′(xs)
|
||||
end
|
||||
end
|
||||
|
||||
function test_stacktrace(bk)
|
||||
@testset "Stack Traces" begin
|
||||
model = TLP(Affine(10, 20), Affine(21, 15))
|
||||
dm = bk(model)
|
||||
e = try dm(rand(1, 10))
|
||||
catch e e end
|
||||
|
||||
@test isa(e, DataFlow.Interpreter.Exception)
|
||||
@test e.trace[1].func == Symbol("Flux.Affine")
|
||||
@test e.trace[2].func == :TLP
|
||||
end
|
||||
end
|
@ -9,30 +9,10 @@ d = Affine(20, 10)
|
||||
dm = mxnet(d)
|
||||
@test d(xs) ≈ dm(xs)
|
||||
|
||||
@testset "Tuple I/O" begin
|
||||
@test mxnet(@net x -> (x,))([1,2,3]) == ([1,2,3],)
|
||||
@test mxnet(@net x -> x[1].*x[2])(([1,2,3],[4,5,6])) == [4,10,18]
|
||||
end
|
||||
|
||||
@testset "Recurrence" begin
|
||||
seq = batchone(Seq(rand(10) for i = 1:3))
|
||||
r = unroll(Recurrent(10, 5), 3)
|
||||
rm = mxnet(r)
|
||||
@test r(seq) ≈ rm(seq)
|
||||
end
|
||||
|
||||
@testset "Backward Pass" begin
|
||||
d′ = deepcopy(d)
|
||||
@test dm(xs) ≈ d(xs)
|
||||
@test dm(xs) ≈ d′(xs)
|
||||
|
||||
Δ = back!(dm, randn(1, 10), xs)
|
||||
@test length(Δ[1]) == 20
|
||||
update!(dm, 0.1)
|
||||
|
||||
@test dm(xs) ≈ d(xs)
|
||||
@test dm(xs) ≉ d′(xs)
|
||||
end
|
||||
test_tupleio(mxnet)
|
||||
test_recurrence(mxnet)
|
||||
test_stacktrace(mxnet)
|
||||
test_back(mxnet)
|
||||
|
||||
@testset "Native interface" begin
|
||||
f = mx.FeedForward(Chain(d, softmax))
|
||||
@ -45,16 +25,4 @@ end
|
||||
@test mx.infer_shape(f.arch, data = (20, 20, 5, 1))[2] == [(10, 1)]
|
||||
end
|
||||
|
||||
@testset "Stack Traces" begin
|
||||
model = TLP(Affine(10, 20), Affine(21, 15))
|
||||
info("The following warning is normal")
|
||||
dm = mxnet(model)
|
||||
e = try dm(rand(1, 10))
|
||||
catch e e end
|
||||
|
||||
@test isa(e, DataFlow.Interpreter.Exception)
|
||||
@test e.trace[1].func == Symbol("Flux.Affine")
|
||||
@test e.trace[2].func == :TLP
|
||||
end
|
||||
|
||||
end
|
||||
|
@ -9,17 +9,9 @@ d = Affine(20, 10)
|
||||
dt = tf(d)
|
||||
@test d(xs) ≈ dt(xs)
|
||||
|
||||
@testset "Tuple I/O" begin
|
||||
@test tf(@net x -> (identity(x),))([1,2,3]) == ([1,2,3],)
|
||||
@test tf(@net x -> x[1].*x[2])(([1,2,3],[4,5,6])) == [4,10,18]
|
||||
end
|
||||
|
||||
@testset "Recurrence" begin
|
||||
seq = batchone(Seq(rand(10) for i = 1:3))
|
||||
r = unroll(Recurrent(10, 5), 3)
|
||||
rm = tf(r)
|
||||
@test r(seq) ≈ rm(seq)
|
||||
end
|
||||
test_tupleio(tf)
|
||||
test_recurrence(tf)
|
||||
test_stacktrace(tf)
|
||||
|
||||
@testset "Tensor interface" begin
|
||||
sess = TensorFlow.Session()
|
||||
@ -30,15 +22,4 @@ end
|
||||
@test run(sess, Y, Dict(X=>xs)) ≈ d(xs)
|
||||
end
|
||||
|
||||
@testset "Stack Traces" begin
|
||||
model = TLP(Affine(10, 20), Affine(21, 15))
|
||||
dm = tf(model)
|
||||
e = try dm(rand(1, 10))
|
||||
catch e e end
|
||||
|
||||
@test isa(e, DataFlow.Interpreter.Exception)
|
||||
@test e.trace[1].func == Symbol("Flux.Affine")
|
||||
@test e.trace[2].func == :TLP
|
||||
end
|
||||
|
||||
end
|
||||
|
@ -1,3 +1,6 @@
|
||||
syntax(v::Vertex) = prettify(DataFlow.syntax(v))
|
||||
syntax(x) = syntax(graph(x))
|
||||
|
||||
@testset "Basics" begin
|
||||
|
||||
xs = randn(1, 10)
|
||||
|
@ -2,9 +2,6 @@ using Flux, DataFlow, MacroTools, Base.Test
|
||||
using Flux: graph, Param, unsqueeze
|
||||
using DataFlow: Line, Frame
|
||||
|
||||
syntax(v::Vertex) = prettify(DataFlow.syntax(v))
|
||||
syntax(x) = syntax(graph(x))
|
||||
|
||||
macro mxonly(ex)
|
||||
:(Base.find_in_path("MXNet") ≠ nothing && $(esc(ex)))
|
||||
end
|
||||
@ -13,16 +10,9 @@ macro tfonly(ex)
|
||||
:(Base.find_in_path("TensorFlow") ≠ nothing && $(esc(ex)))
|
||||
end
|
||||
|
||||
@net type TLP
|
||||
first
|
||||
second
|
||||
function (x)
|
||||
l1 = σ(first(x))
|
||||
l2 = softmax(second(l1))
|
||||
end
|
||||
end
|
||||
|
||||
include("batching.jl")
|
||||
include("backend/common.jl")
|
||||
|
||||
include("basic.jl")
|
||||
include("recurrent.jl")
|
||||
@tfonly include("backend/tensorflow.jl")
|
||||
|
Loading…
Reference in New Issue
Block a user