diff --git a/src/backend/mxnet/graph.jl b/src/backend/mxnet/graph.jl
index 70fc413f..1d074a19 100644
--- a/src/backend/mxnet/graph.jl
+++ b/src/backend/mxnet/graph.jl
@@ -18,6 +18,7 @@ node(x::Tuple) = map(node, x)
 node(x::mx.SymbolicNode) = x
 
 graph(::typeof(tuple), args...) = (args...,)
+graph(::typeof(identity), x) = x
 graph(::typeof(getindex), t::Tuple, n::Integer) = t[n]
 graph(::typeof(.+), args...) = mx.broadcast_plus(args...)
 graph(::typeof(.*), args...) = mx.broadcast_mul(args...)
diff --git a/test/backend/common.jl b/test/backend/common.jl
new file mode 100644
index 00000000..1ed0122c
--- /dev/null
+++ b/test/backend/common.jl
@@ -0,0 +1,57 @@
+@net type TLP
+  first
+  second
+  function (x)
+    l1 = σ(first(x))
+    l2 = softmax(second(l1))
+  end
+end
+
+function test_tupleio(bk)
+  @testset "Tuple I/O" begin
+    val = [1,2,3]
+    tup = ([1,2,3],[4,5,6])
+    @test bk(@net x -> (identity(x),))(val) == (val,)
+    @test bk(@net x -> x[1].*x[2])(tup) == [4,10,18]
+  end
+end
+
+function test_recurrence(bk)
+  @testset "Recurrence" begin
+    seq = batchone(Seq(rand(10) for i = 1:3))
+    r = unroll(Recurrent(10, 5), 3)
+    rm = bk(r)
+    @test r(seq) ≈ rm(seq)
+  end
+end
+
+function test_back(bk)
+  @testset "Backward Pass" begin
+    xs, ys = rand(1, 20), rand(1, 20)
+    d = Affine(20, 10)
+    dm = bk(d)
+    d′ = deepcopy(d)
+    @test dm(xs) ≈ d(xs)
+    @test dm(xs) ≈ d′(xs)
+
+    Δ = back!(dm, randn(1, 10), xs)
+    @test length(Δ[1]) == 20
+    update!(dm, 0.1)
+
+    @test dm(xs) ≈ d(xs)
+    @test dm(xs) ≉ d′(xs)
+  end
+end
+
+function test_stacktrace(bk)
+  @testset "Stack Traces" begin
+    model = TLP(Affine(10, 20), Affine(21, 15))
+    dm = bk(model)
+    e = try dm(rand(1, 10))
+    catch e e end
+
+    @test isa(e, DataFlow.Interpreter.Exception)
+    @test e.trace[1].func == Symbol("Flux.Affine")
+    @test e.trace[2].func == :TLP
+  end
+end
diff --git a/test/backend/mxnet.jl b/test/backend/mxnet.jl
index 3bd1927f..29e51dd2 100644
--- a/test/backend/mxnet.jl
+++ b/test/backend/mxnet.jl
@@ -9,30 +9,10 @@ d = Affine(20, 10)
 dm = mxnet(d)
 @test d(xs) ≈ dm(xs)
 
-@testset "Tuple I/O" begin
-  @test mxnet(@net x -> (x,))([1,2,3]) == ([1,2,3],)
-  @test mxnet(@net x -> x[1].*x[2])(([1,2,3],[4,5,6])) == [4,10,18]
-end
-
-@testset "Recurrence" begin
-  seq = batchone(Seq(rand(10) for i = 1:3))
-  r = unroll(Recurrent(10, 5), 3)
-  rm = mxnet(r)
-  @test r(seq) ≈ rm(seq)
-end
-
-@testset "Backward Pass" begin
-  d′ = deepcopy(d)
-  @test dm(xs) ≈ d(xs)
-  @test dm(xs) ≈ d′(xs)
-
-  Δ = back!(dm, randn(1, 10), xs)
-  @test length(Δ[1]) == 20
-  update!(dm, 0.1)
-
-  @test dm(xs) ≈ d(xs)
-  @test dm(xs) ≉ d′(xs)
-end
+test_tupleio(mxnet)
+test_recurrence(mxnet)
+test_stacktrace(mxnet)
+test_back(mxnet)
 
 @testset "Native interface" begin
   f = mx.FeedForward(Chain(d, softmax))
@@ -45,16 +25,4 @@
   @test mx.infer_shape(f.arch, data = (20, 20, 5, 1))[2] == [(10, 1)]
 end
 
-@testset "Stack Traces" begin
-  model = TLP(Affine(10, 20), Affine(21, 15))
-  info("The following warning is normal")
-  dm = mxnet(model)
-  e = try dm(rand(1, 10))
-  catch e e end
-
-  @test isa(e, DataFlow.Interpreter.Exception)
-  @test e.trace[1].func == Symbol("Flux.Affine")
-  @test e.trace[2].func == :TLP
-end
-
 end
diff --git a/test/backend/tensorflow.jl b/test/backend/tensorflow.jl
index 5d75d210..ab9d4af7 100644
--- a/test/backend/tensorflow.jl
+++ b/test/backend/tensorflow.jl
@@ -9,17 +9,9 @@ d = Affine(20, 10)
 dt = tf(d)
 @test d(xs) ≈ dt(xs)
 
-@testset "Tuple I/O" begin
-  @test tf(@net x -> (identity(x),))([1,2,3]) == ([1,2,3],)
-  @test tf(@net x -> x[1].*x[2])(([1,2,3],[4,5,6])) == [4,10,18]
-end
-
-@testset "Recurrence" begin
-  seq = batchone(Seq(rand(10) for i = 1:3))
-  r = unroll(Recurrent(10, 5), 3)
-  rm = tf(r)
-  @test r(seq) ≈ rm(seq)
-end
+test_tupleio(tf)
+test_recurrence(tf)
+test_stacktrace(tf)
 
 @testset "Tensor interface" begin
   sess = TensorFlow.Session()
@@ -30,15 +22,4 @@
   @test run(sess, Y, Dict(X=>xs)) ≈ d(xs)
 end
 
-@testset "Stack Traces" begin
-  model = TLP(Affine(10, 20), Affine(21, 15))
-  dm = tf(model)
-  e = try dm(rand(1, 10))
-  catch e e end
-
-  @test isa(e, DataFlow.Interpreter.Exception)
-  @test e.trace[1].func == Symbol("Flux.Affine")
-  @test e.trace[2].func == :TLP
-end
-
 end
diff --git a/test/basic.jl b/test/basic.jl
index 708e8f83..56ef16d5 100644
--- a/test/basic.jl
+++ b/test/basic.jl
@@ -1,3 +1,6 @@
+syntax(v::Vertex) = prettify(DataFlow.syntax(v))
+syntax(x) = syntax(graph(x))
+
 @testset "Basics" begin
 
 xs = randn(1, 10)
diff --git a/test/runtests.jl b/test/runtests.jl
index e52c33cd..e110725e 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -2,9 +2,6 @@ using Flux, DataFlow, MacroTools, Base.Test
 using Flux: graph, Param, unsqueeze
 using DataFlow: Line, Frame
 
-syntax(v::Vertex) = prettify(DataFlow.syntax(v))
-syntax(x) = syntax(graph(x))
-
 macro mxonly(ex)
   :(Base.find_in_path("MXNet") ≠ nothing && $(esc(ex)))
 end
@@ -13,16 +10,9 @@ macro tfonly(ex)
   :(Base.find_in_path("TensorFlow") ≠ nothing && $(esc(ex)))
 end
 
-@net type TLP
-  first
-  second
-  function (x)
-    l1 = σ(first(x))
-    l2 = softmax(second(l1))
-  end
-end
-
 include("batching.jl")
+include("backend/common.jl")
+
 include("basic.jl")
 include("recurrent.jl")
 @tfonly include("backend/tensorflow.jl")