# Tests for the MXNet backend: Flux models converted with `mxnet` must
# reproduce the native Flux forward pass, support `back!`/`update!`, build
# `mx.FeedForward` graphs with the expected inferred output shapes, and
# surface model-level stack traces when evaluation fails.
using MXNet
Flux.loadmx()

@testset "MXNet" begin

# Shared fixtures: two 1×20 inputs and a 20 → 10 affine layer together with
# its MXNet counterpart. `d` and `dm` are reused by the sub-testsets below.
xs, ys = rand(1, 20), rand(1, 20)
d = Affine(20, 10)
dm = mxnet(d)
@test d(xs) ≈ dm(xs)

# Two-input model; results are compared elementwise with `isapprox.`
# (presumably because the model returns several outputs).
m = Multi(20, 15)
mm = mxnet(m)
@test all(isapprox.(mm(xs, ys), m(xs, ys)))

@testset "Recurrence" begin
    # A batch of one sequence made of three 10-element vectors.
    seq = batchone(Seq(rand(10) for i = 1:3))
    r = unroll(Recurrent(10, 5), 3)
    # Named `rmx` rather than `rm` to avoid shadowing `Base.rm`.
    rmx = mxnet(r)
    @test r(seq) ≈ rmx(seq)
end

@testset "Backward Pass" begin
    # Snapshot of `d` taken before any update, used to detect that
    # `update!` actually changes the model's output.
    d′ = deepcopy(d)
    @test dm(xs) ≈ d(xs)
    @test dm(xs) ≈ d′(xs)

    # Backpropagate a random 1×10 output gradient; the first returned
    # gradient must have one entry per input feature (20).
    Δ = back!(dm, randn(1, 10), xs)
    @test length(Δ[1]) == 20
    update!(dm, 0.1)

    # After the update the backend is expected to still agree with `d`
    # (presumably parameters are shared with / written back to `d`), but no
    # longer with the pre-update copy `d′`.
    @test dm(xs) ≈ d(xs)
    @test dm(xs) ≉ d′(xs)
end

@testset "Native interface" begin
    f = mx.FeedForward(Chain(d, softmax))
    @test mx.infer_shape(f.arch, data = (20, 1))[2] == [(10, 1)]

    m = Chain(Input(28,28), Conv2D((5,5), out = 3), MaxPool((2,2)),
              flatten, Affine(1587, 10), softmax)
    f = mx.FeedForward(m)
    # TODO: test run
    # NOTE(review): `data = (20, 20, 5, 1)` does not obviously correspond to
    # `Input(28, 28)`; confirm the intended input shape.
    @test mx.infer_shape(f.arch, data = (20, 20, 5, 1))[2] == [(10, 1)]
end

@testset "Stack Traces" begin
    # The layer sizes do not line up (the first layer produces 20 features,
    # the second expects 21), so evaluating the model is expected to throw.
    model = TLP(Affine(10, 20), Affine(21, 15))
    info("The following warning is normal")
    dm = mxnet(model)
    # Capture the thrown exception. `nothing` means no error was raised,
    # which makes the `isa` check below fail instead of inspecting a result.
    e = try
        dm(rand(1, 10))
        nothing
    catch err
        err
    end

    # The error must come from DataFlow's interpreter, with the first trace
    # entry naming the failing `Flux.Affine` layer and the second `TLP`.
    @test isa(e, DataFlow.Interpreter.Exception)
    @test e.trace[1].func == Symbol("Flux.Affine")
    @test e.trace[2].func == :TLP
end

end