2017-02-23 22:28:18 +00:00
|
|
|
using MXNet
|
|
|
|
# Called right after `using MXNet`; presumably this hooks Flux's MXNet backend
# (e.g. `mxnet`, used below) into the session — TODO confirm against Flux's loader.
Flux.loadmx()
|
|
|
|
|
|
|
|
@testset "MXNet" begin
    # One random 1×20 input used for the numerical comparison below.
    xs = rand(1, 20)
    d = Affine(20, 10)

    # Converting a model with `mxnet` must not change its output: compare the
    # native Flux forward pass against the MXNet-backed one on the same input.
    dm = mxnet(d)
    @test d(xs) ≈ dm(xs)

    # Backend test helpers, presumably defined elsewhere in the test suite; each
    # is handed the `mxnet` conversion function. Judging by their names they
    # cover tuple I/O, recurrence, stack traces, backprop and anonymous layers
    # — TODO confirm against their definitions.
    test_tupleio(mxnet)
    test_recurrence(mxnet)
    test_stacktrace(mxnet)
    test_back(mxnet)
    test_anon(mxnet)

    @testset "Native interface" begin
        # Wrap a Flux model in MXNet.jl's `FeedForward` and check the inferred
        # output shape of its `arch` (presumably the symbolic graph). Element [2]
        # of `mx.infer_shape` is the output-shape list, since MXNet.jl returns
        # (arg_shapes, out_shapes, aux_shapes). Shapes assume MXNet.jl's
        # batch-last layout: 20 features × batch 1 → 10 outputs × batch 1.
        fdense = mx.FeedForward(Chain(d, softmax))
        @test mx.infer_shape(fdense.arch, data = (20, 1))[2] == [(10, 1)]

        # NOTE(review): the layer sizes here do not obviously agree. Assuming a
        # "valid" 5×5 convolution and a stride-2 2×2 pool:
        #   - a 28×28 input gives 12×12×3 = 432 features;
        #   - the 20×20×5 probe below gives 8×8×3 = 192 features;
        #   - yet the dense layer declares 1587 (= 23×23×3) inputs.
        # Presumably the assertion still holds because MXNet infers the dense
        # layer's input width from the graph — TODO confirm the intended sizes.
        m = Chain(Input(28, 28), Conv2D((5, 5), out = 3), MaxPool((2, 2)),
                  flatten, Affine(1587, 10), softmax)
        fconv = mx.FeedForward(m)
        # TODO: test run (only shape inference is checked so far).
        @test mx.infer_shape(fconv.arch, data = (20, 20, 5, 1))[2] == [(10, 1)]
    end
end
|