organise tests
This commit is contained in:
parent
0222103c7f
commit
e05fea7eb4
@ -1,47 +0,0 @@
|
|||||||
@net type TLP
|
|
||||||
first
|
|
||||||
second
|
|
||||||
function (x)
|
|
||||||
l1 = σ(first(x))
|
|
||||||
l2 = softmax(second(l1))
|
|
||||||
end
|
|
||||||
end
|
|
||||||
|
|
||||||
function test_tupleio(bk)
|
|
||||||
@testset "Tuple I/O" begin
|
|
||||||
val = [1,2,3]
|
|
||||||
tup = ([1,2,3],[4,5,6])
|
|
||||||
@test bk(@net x -> (identity(x),))(val) == (val,)
|
|
||||||
@test bk(@net x -> x[1].*x[2])(tup) == [4,10,18]
|
|
||||||
end
|
|
||||||
end
|
|
||||||
|
|
||||||
function test_recurrence(bk)
|
|
||||||
@testset "Recurrence" begin
|
|
||||||
seq = unsqueeze(stack(rand(10) for i = 1:3))
|
|
||||||
r = Flux.Compiler.unroll(Recurrent(10, 5), 3)
|
|
||||||
rm = bk(r)
|
|
||||||
@test r(seq) ≈ rm(seq)
|
|
||||||
end
|
|
||||||
end
|
|
||||||
|
|
||||||
function test_stacktrace(bk)
|
|
||||||
@testset "Stack Traces" begin
|
|
||||||
model = TLP(Affine(10, 20), Affine(21, 15))
|
|
||||||
dm = bk(model)
|
|
||||||
e = try dm(rand(1, 10))
|
|
||||||
catch e e end
|
|
||||||
|
|
||||||
@test isa(e, DataFlow.Interpreter.Exception)
|
|
||||||
@test e.trace[1].func == Symbol("Flux.Affine")
|
|
||||||
@test e.trace[2].func == :TLP
|
|
||||||
end
|
|
||||||
end
|
|
||||||
|
|
||||||
function test_anon(bk)
|
|
||||||
@testset "Closures" begin
|
|
||||||
x, y = rand(3), rand(5)
|
|
||||||
model = bk(@net xs -> map(x -> x .* x, xs))
|
|
||||||
@test all(model((x, y)) .≈ (x.*x, y.*y))
|
|
||||||
end
|
|
||||||
end
|
|
@ -1,39 +0,0 @@
|
|||||||
using Flux: Affine
|
|
||||||
|
|
||||||
syntax(v::Vertex) = prettify(DataFlow.syntax(v))
|
|
||||||
syntax(x) = syntax(graph(x))
|
|
||||||
|
|
||||||
@testset "Basics" begin
|
|
||||||
|
|
||||||
xs = randn(1, 10)
|
|
||||||
d = Affine(10, 20)
|
|
||||||
|
|
||||||
@test d(xs) ≈ (xs*d.W.x + d.b.x)
|
|
||||||
|
|
||||||
d1 = @net x -> x * d.W + d.b
|
|
||||||
|
|
||||||
# Skip this before new DataFlow is released.
|
|
||||||
# let
|
|
||||||
# @test @capture(syntax(d), _Frame(_Line((+).(x_[1] * W_, b_))))
|
|
||||||
# @test isa(x, DataFlow.Input) && isa(W, Param) && isa(b, Param)
|
|
||||||
# end
|
|
||||||
|
|
||||||
test_anon(identity)
|
|
||||||
|
|
||||||
let a1 = Affine(10, 20), a2 = Affine(20, 15)
|
|
||||||
tlp = TLP(a1, a2)
|
|
||||||
@test tlp(xs) ≈ softmax(a2(σ(a1(xs))))
|
|
||||||
@test Flux.Compiler.interpmodel(tlp, xs) ≈ softmax(a2(σ(a1(xs))))
|
|
||||||
end
|
|
||||||
|
|
||||||
let tlp = TLP(Affine(10, 21), Affine(20, 15))
|
|
||||||
e = try
|
|
||||||
Flux.Compiler.interpmodel(tlp, rand(1, 10))
|
|
||||||
catch e
|
|
||||||
e
|
|
||||||
end
|
|
||||||
@test e.trace[end].func == :TLP
|
|
||||||
@test e.trace[end-1].func == Symbol("Flux.Affine")
|
|
||||||
end
|
|
||||||
|
|
||||||
end
|
|
66
test/compiler.jl
Normal file
66
test/compiler.jl
Normal file
@ -0,0 +1,66 @@
|
|||||||
|
using DataFlow, MacroTools
|
||||||
|
using Flux: Affine, Param, Recurrent, squeeze, unsqueeze, stack
|
||||||
|
using Flux.Compiler: @net, graph
|
||||||
|
using DataFlow: Line, Frame
|
||||||
|
|
||||||
|
@net type TLP
|
||||||
|
first
|
||||||
|
second
|
||||||
|
function (x)
|
||||||
|
l1 = σ(first(x))
|
||||||
|
l2 = softmax(second(l1))
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
syntax(v::Vertex) = prettify(DataFlow.syntax(v))
|
||||||
|
syntax(x) = syntax(graph(x))
|
||||||
|
|
||||||
|
@testset "Compiler" begin
|
||||||
|
|
||||||
|
xs = randn(1, 10)
|
||||||
|
d = Affine(10, 20)
|
||||||
|
|
||||||
|
@test d(xs) ≈ (xs*d.W.x + d.b.x)
|
||||||
|
|
||||||
|
d1 = @net x -> x * d.W + d.b
|
||||||
|
|
||||||
|
let
|
||||||
|
@capture(syntax(d), _Frame(_Line((+).(x_[1] * W_, b_))))
|
||||||
|
@test isa(x, DataFlow.Input) && isa(W, Param) && isa(b, Param)
|
||||||
|
end
|
||||||
|
|
||||||
|
let a1 = Affine(10, 20), a2 = Affine(20, 15)
|
||||||
|
tlp = TLP(a1, a2)
|
||||||
|
@test tlp(xs) ≈ softmax(a2(σ(a1(xs))))
|
||||||
|
@test Flux.Compiler.interpmodel(tlp, xs) ≈ softmax(a2(σ(a1(xs))))
|
||||||
|
end
|
||||||
|
|
||||||
|
let tlp = TLP(Affine(10, 21), Affine(20, 15))
|
||||||
|
e = try
|
||||||
|
Flux.Compiler.interpmodel(tlp, rand(1, 10))
|
||||||
|
catch e
|
||||||
|
e
|
||||||
|
end
|
||||||
|
@test e.trace[end].func == :TLP
|
||||||
|
@test e.trace[end-1].func == Symbol("Flux.Affine")
|
||||||
|
end
|
||||||
|
|
||||||
|
function apply(model, xs, state)
|
||||||
|
ys = similar(xs, 0)
|
||||||
|
for x in xs
|
||||||
|
state, y = model(state, x)
|
||||||
|
push!(ys, y)
|
||||||
|
end
|
||||||
|
state, ys
|
||||||
|
end
|
||||||
|
|
||||||
|
@testset "RNN unrolling" begin
|
||||||
|
r = Recurrent(10, 5)
|
||||||
|
xs = [rand(1, 10) for _ = 1:3]
|
||||||
|
_, ys = apply(Flux.Compiler.unroll1(r).model, xs, (r.y.x,))
|
||||||
|
@test ys[1] == tanh(xs[1] * r.Wxy.x .+ r.y.x * r.Wyy.x .+ r.by.x)
|
||||||
|
ru = Flux.Compiler.unroll(r, 3)
|
||||||
|
ru(unsqueeze(stack(squeeze.(xs))))[1] == squeeze.(ys)
|
||||||
|
end
|
||||||
|
|
||||||
|
end
|
@ -1,19 +0,0 @@
|
|||||||
using Flux: Recurrent
|
|
||||||
|
|
||||||
function apply(model, xs, state)
|
|
||||||
ys = similar(xs, 0)
|
|
||||||
for x in xs
|
|
||||||
state, y = model(state, x)
|
|
||||||
push!(ys, y)
|
|
||||||
end
|
|
||||||
state, ys
|
|
||||||
end
|
|
||||||
|
|
||||||
@testset "RNN unrolling" begin
|
|
||||||
r = Recurrent(10, 5)
|
|
||||||
xs = [rand(1, 10) for _ = 1:3]
|
|
||||||
_, ys = apply(Flux.Compiler.unroll1(r).model, xs, (r.y.x,))
|
|
||||||
@test ys[1] == tanh(xs[1] * r.Wxy.x .+ r.y.x * r.Wyy.x .+ r.by.x)
|
|
||||||
ru = Flux.Compiler.unroll(r, 3)
|
|
||||||
ru(unsqueeze(stack(squeeze.(xs))))[1] == squeeze.(ys)
|
|
||||||
end
|
|
@ -1,14 +1,8 @@
|
|||||||
using Flux, DataFlow, MacroTools, Base.Test
|
using Flux, Base.Test
|
||||||
using Flux: Param, param, squeeze, unsqueeze, stack, update!, flatten
|
|
||||||
using Flux.Compiler: @net
|
|
||||||
using DataFlow: Line, Frame
|
|
||||||
|
|
||||||
@testset "Flux" begin
|
@testset "Flux" begin
|
||||||
|
|
||||||
include("backend/common.jl")
|
include("compiler.jl")
|
||||||
|
include("utils.jl")
|
||||||
include("basic.jl")
|
|
||||||
include("recurrent.jl")
|
|
||||||
include("throttle.jl")
|
|
||||||
|
|
||||||
end
|
end
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
using Flux.throttle
|
using Flux: throttle
|
||||||
|
|
||||||
@testset "throttle" begin
|
@testset "Throttle" begin
|
||||||
@testset "default behaviour" begin
|
@testset "default behaviour" begin
|
||||||
a = []
|
a = []
|
||||||
f = throttle(()->push!(a, now()), 1, leading=true, trailing=false)
|
f = throttle(()->push!(a, now()), 1, leading=true, trailing=false)
|
Loading…
Reference in New Issue
Block a user