more mxnet integration checks
This commit is contained in:
parent
092f2038b3
commit
15816bdbaf
@ -8,10 +8,15 @@ d = Affine(20, 10)
|
|||||||
end
|
end
|
||||||
|
|
||||||
@mxonly let
|
@mxonly let
|
||||||
|
# TODO: test run
|
||||||
using MXNet
|
using MXNet
|
||||||
f = mx.FeedForward(Chain(d, softmax))
|
f = mx.FeedForward(Chain(d, softmax))
|
||||||
@test isa(f, mx.FeedForward)
|
@test mx.infer_shape(f.arch, data = (20, 1))[2] == [(10, 1)]
|
||||||
# TODO: test run
|
|
||||||
|
m = Chain(Input(28,28), Conv2D((5,5), out = 3), MaxPool((2,2)),
|
||||||
|
flatten, Affine(1587, 10), softmax)
|
||||||
|
f = mx.FeedForward(m)
|
||||||
|
@test mx.infer_shape(f.arch, data = (20, 20, 5, 1))[2] == [(10, 1)]
|
||||||
end
|
end
|
||||||
|
|
||||||
# TensorFlow
|
# TensorFlow
|
||||||
|
@ -6,7 +6,7 @@ d = Affine(10, 20)
|
|||||||
|
|
||||||
let
|
let
|
||||||
@capture(syntax(d), _Frame(_Line(x_[1] * W_ + b_)))
|
@capture(syntax(d), _Frame(_Line(x_[1] * W_ + b_)))
|
||||||
@test isa(x, Input) && isa(W, Param) && isa(b, Param)
|
@test isa(x, DataFlow.Input) && isa(W, Param) && isa(b, Param)
|
||||||
end
|
end
|
||||||
|
|
||||||
@net type TLP
|
@net type TLP
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
using Flux, DataFlow, MacroTools, Base.Test
|
using Flux, DataFlow, MacroTools, Base.Test
|
||||||
using Flux: graph, Param
|
using Flux: graph, Param
|
||||||
using DataFlow: Input, Line, Frame
|
using DataFlow: Line, Frame
|
||||||
|
|
||||||
syntax(v::Vertex) = prettify(DataFlow.syntax(v))
|
syntax(v::Vertex) = prettify(DataFlow.syntax(v))
|
||||||
syntax(x) = syntax(graph(x))
|
syntax(x) = syntax(graph(x))
|
||||||
|
Loading…
Reference in New Issue
Block a user