more mxnet integration checks

This commit is contained in:
Mike J Innes 2017-02-21 15:46:38 +00:00
parent 092f2038b3
commit 15816bdbaf
3 changed files with 9 additions and 4 deletions

View File

@ -8,10 +8,15 @@ d = Affine(20, 10)
end
@mxonly let
# TODO: test run
using MXNet
f = mx.FeedForward(Chain(d, softmax))
@test isa(f, mx.FeedForward)
# TODO: test run
@test mx.infer_shape(f.arch, data = (20, 1))[2] == [(10, 1)]
m = Chain(Input(28,28), Conv2D((5,5), out = 3), MaxPool((2,2)),
flatten, Affine(1587, 10), softmax)
f = mx.FeedForward(m)
@test mx.infer_shape(f.arch, data = (20, 20, 5, 1))[2] == [(10, 1)]
end
# TensorFlow

View File

@ -6,7 +6,7 @@ d = Affine(10, 20)
let
@capture(syntax(d), _Frame(_Line(x_[1] * W_ + b_)))
@test isa(x, Input) && isa(W, Param) && isa(b, Param)
@test isa(x, DataFlow.Input) && isa(W, Param) && isa(b, Param)
end
@net type TLP

View File

@ -1,6 +1,6 @@
using Flux, DataFlow, MacroTools, Base.Test
using Flux: graph, Param
using DataFlow: Input, Line, Frame
using DataFlow: Line, Frame
syntax(v::Vertex) = prettify(DataFlow.syntax(v))
syntax(x) = syntax(graph(x))