factor out common tests
This commit is contained in:
parent
70286c0bf3
commit
a2db4b5319
@ -18,6 +18,7 @@ node(x::Tuple) = map(node, x)
|
|||||||
node(x::mx.SymbolicNode) = x
|
node(x::mx.SymbolicNode) = x
|
||||||
|
|
||||||
graph(::typeof(tuple), args...) = (args...,)
|
graph(::typeof(tuple), args...) = (args...,)
|
||||||
|
graph(::typeof(identity), x) = x
|
||||||
graph(::typeof(getindex), t::Tuple, n::Integer) = t[n]
|
graph(::typeof(getindex), t::Tuple, n::Integer) = t[n]
|
||||||
graph(::typeof(.+), args...) = mx.broadcast_plus(args...)
|
graph(::typeof(.+), args...) = mx.broadcast_plus(args...)
|
||||||
graph(::typeof(.*), args...) = mx.broadcast_mul(args...)
|
graph(::typeof(.*), args...) = mx.broadcast_mul(args...)
|
||||||
|
57
test/backend/common.jl
Normal file
57
test/backend/common.jl
Normal file
@ -0,0 +1,57 @@
|
|||||||
|
@net type TLP
|
||||||
|
first
|
||||||
|
second
|
||||||
|
function (x)
|
||||||
|
l1 = σ(first(x))
|
||||||
|
l2 = softmax(second(l1))
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
function test_tupleio(bk)
|
||||||
|
@testset "Tuple I/O" begin
|
||||||
|
val = [1,2,3]
|
||||||
|
tup = ([1,2,3],[4,5,6])
|
||||||
|
@test bk(@net x -> (identity(x),))(val) == (val,)
|
||||||
|
@test bk(@net x -> x[1].*x[2])(tup) == [4,10,18]
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
function test_recurrence(bk)
|
||||||
|
@testset "Recurrence" begin
|
||||||
|
seq = batchone(Seq(rand(10) for i = 1:3))
|
||||||
|
r = unroll(Recurrent(10, 5), 3)
|
||||||
|
rm = bk(r)
|
||||||
|
@test r(seq) ≈ rm(seq)
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
function test_back(bk)
|
||||||
|
@testset "Backward Pass" begin
|
||||||
|
xs, ys = rand(1, 20), rand(1, 20)
|
||||||
|
d = Affine(20, 10)
|
||||||
|
dm = bk(d)
|
||||||
|
d′ = deepcopy(d)
|
||||||
|
@test dm(xs) ≈ d(xs)
|
||||||
|
@test dm(xs) ≈ d′(xs)
|
||||||
|
|
||||||
|
Δ = back!(dm, randn(1, 10), xs)
|
||||||
|
@test length(Δ[1]) == 20
|
||||||
|
update!(dm, 0.1)
|
||||||
|
|
||||||
|
@test dm(xs) ≈ d(xs)
|
||||||
|
@test dm(xs) ≉ d′(xs)
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
function test_stacktrace(bk)
|
||||||
|
@testset "Stack Traces" begin
|
||||||
|
model = TLP(Affine(10, 20), Affine(21, 15))
|
||||||
|
dm = bk(model)
|
||||||
|
e = try dm(rand(1, 10))
|
||||||
|
catch e e end
|
||||||
|
|
||||||
|
@test isa(e, DataFlow.Interpreter.Exception)
|
||||||
|
@test e.trace[1].func == Symbol("Flux.Affine")
|
||||||
|
@test e.trace[2].func == :TLP
|
||||||
|
end
|
||||||
|
end
|
@ -9,30 +9,10 @@ d = Affine(20, 10)
|
|||||||
dm = mxnet(d)
|
dm = mxnet(d)
|
||||||
@test d(xs) ≈ dm(xs)
|
@test d(xs) ≈ dm(xs)
|
||||||
|
|
||||||
@testset "Tuple I/O" begin
|
test_tupleio(mxnet)
|
||||||
@test mxnet(@net x -> (x,))([1,2,3]) == ([1,2,3],)
|
test_recurrence(mxnet)
|
||||||
@test mxnet(@net x -> x[1].*x[2])(([1,2,3],[4,5,6])) == [4,10,18]
|
test_stacktrace(mxnet)
|
||||||
end
|
test_back(mxnet)
|
||||||
|
|
||||||
@testset "Recurrence" begin
|
|
||||||
seq = batchone(Seq(rand(10) for i = 1:3))
|
|
||||||
r = unroll(Recurrent(10, 5), 3)
|
|
||||||
rm = mxnet(r)
|
|
||||||
@test r(seq) ≈ rm(seq)
|
|
||||||
end
|
|
||||||
|
|
||||||
@testset "Backward Pass" begin
|
|
||||||
d′ = deepcopy(d)
|
|
||||||
@test dm(xs) ≈ d(xs)
|
|
||||||
@test dm(xs) ≈ d′(xs)
|
|
||||||
|
|
||||||
Δ = back!(dm, randn(1, 10), xs)
|
|
||||||
@test length(Δ[1]) == 20
|
|
||||||
update!(dm, 0.1)
|
|
||||||
|
|
||||||
@test dm(xs) ≈ d(xs)
|
|
||||||
@test dm(xs) ≉ d′(xs)
|
|
||||||
end
|
|
||||||
|
|
||||||
@testset "Native interface" begin
|
@testset "Native interface" begin
|
||||||
f = mx.FeedForward(Chain(d, softmax))
|
f = mx.FeedForward(Chain(d, softmax))
|
||||||
@ -45,16 +25,4 @@ end
|
|||||||
@test mx.infer_shape(f.arch, data = (20, 20, 5, 1))[2] == [(10, 1)]
|
@test mx.infer_shape(f.arch, data = (20, 20, 5, 1))[2] == [(10, 1)]
|
||||||
end
|
end
|
||||||
|
|
||||||
@testset "Stack Traces" begin
|
|
||||||
model = TLP(Affine(10, 20), Affine(21, 15))
|
|
||||||
info("The following warning is normal")
|
|
||||||
dm = mxnet(model)
|
|
||||||
e = try dm(rand(1, 10))
|
|
||||||
catch e e end
|
|
||||||
|
|
||||||
@test isa(e, DataFlow.Interpreter.Exception)
|
|
||||||
@test e.trace[1].func == Symbol("Flux.Affine")
|
|
||||||
@test e.trace[2].func == :TLP
|
|
||||||
end
|
|
||||||
|
|
||||||
end
|
end
|
||||||
|
@ -9,17 +9,9 @@ d = Affine(20, 10)
|
|||||||
dt = tf(d)
|
dt = tf(d)
|
||||||
@test d(xs) ≈ dt(xs)
|
@test d(xs) ≈ dt(xs)
|
||||||
|
|
||||||
@testset "Tuple I/O" begin
|
test_tupleio(tf)
|
||||||
@test tf(@net x -> (identity(x),))([1,2,3]) == ([1,2,3],)
|
test_recurrence(tf)
|
||||||
@test tf(@net x -> x[1].*x[2])(([1,2,3],[4,5,6])) == [4,10,18]
|
test_stacktrace(tf)
|
||||||
end
|
|
||||||
|
|
||||||
@testset "Recurrence" begin
|
|
||||||
seq = batchone(Seq(rand(10) for i = 1:3))
|
|
||||||
r = unroll(Recurrent(10, 5), 3)
|
|
||||||
rm = tf(r)
|
|
||||||
@test r(seq) ≈ rm(seq)
|
|
||||||
end
|
|
||||||
|
|
||||||
@testset "Tensor interface" begin
|
@testset "Tensor interface" begin
|
||||||
sess = TensorFlow.Session()
|
sess = TensorFlow.Session()
|
||||||
@ -30,15 +22,4 @@ end
|
|||||||
@test run(sess, Y, Dict(X=>xs)) ≈ d(xs)
|
@test run(sess, Y, Dict(X=>xs)) ≈ d(xs)
|
||||||
end
|
end
|
||||||
|
|
||||||
@testset "Stack Traces" begin
|
|
||||||
model = TLP(Affine(10, 20), Affine(21, 15))
|
|
||||||
dm = tf(model)
|
|
||||||
e = try dm(rand(1, 10))
|
|
||||||
catch e e end
|
|
||||||
|
|
||||||
@test isa(e, DataFlow.Interpreter.Exception)
|
|
||||||
@test e.trace[1].func == Symbol("Flux.Affine")
|
|
||||||
@test e.trace[2].func == :TLP
|
|
||||||
end
|
|
||||||
|
|
||||||
end
|
end
|
||||||
|
@ -1,3 +1,6 @@
|
|||||||
|
syntax(v::Vertex) = prettify(DataFlow.syntax(v))
|
||||||
|
syntax(x) = syntax(graph(x))
|
||||||
|
|
||||||
@testset "Basics" begin
|
@testset "Basics" begin
|
||||||
|
|
||||||
xs = randn(1, 10)
|
xs = randn(1, 10)
|
||||||
|
@ -2,9 +2,6 @@ using Flux, DataFlow, MacroTools, Base.Test
|
|||||||
using Flux: graph, Param, unsqueeze
|
using Flux: graph, Param, unsqueeze
|
||||||
using DataFlow: Line, Frame
|
using DataFlow: Line, Frame
|
||||||
|
|
||||||
syntax(v::Vertex) = prettify(DataFlow.syntax(v))
|
|
||||||
syntax(x) = syntax(graph(x))
|
|
||||||
|
|
||||||
macro mxonly(ex)
|
macro mxonly(ex)
|
||||||
:(Base.find_in_path("MXNet") ≠ nothing && $(esc(ex)))
|
:(Base.find_in_path("MXNet") ≠ nothing && $(esc(ex)))
|
||||||
end
|
end
|
||||||
@ -13,16 +10,9 @@ macro tfonly(ex)
|
|||||||
:(Base.find_in_path("TensorFlow") ≠ nothing && $(esc(ex)))
|
:(Base.find_in_path("TensorFlow") ≠ nothing && $(esc(ex)))
|
||||||
end
|
end
|
||||||
|
|
||||||
@net type TLP
|
|
||||||
first
|
|
||||||
second
|
|
||||||
function (x)
|
|
||||||
l1 = σ(first(x))
|
|
||||||
l2 = softmax(second(l1))
|
|
||||||
end
|
|
||||||
end
|
|
||||||
|
|
||||||
include("batching.jl")
|
include("batching.jl")
|
||||||
|
include("backend/common.jl")
|
||||||
|
|
||||||
include("basic.jl")
|
include("basic.jl")
|
||||||
include("recurrent.jl")
|
include("recurrent.jl")
|
||||||
@tfonly include("backend/tensorflow.jl")
|
@tfonly include("backend/tensorflow.jl")
|
||||||
|
Loading…
Reference in New Issue
Block a user