mxnet integration test
This commit is contained in:
parent
ad4d60f90d
commit
a14cbf301d
@ -1,15 +1,24 @@
|
|||||||
xs = rand(20)
|
xs = rand(20)
|
||||||
d = Affine(20, 10)
|
d = Affine(20, 10)
|
||||||
|
|
||||||
@tfonly let dt = tf(d)
|
# MXNet
|
||||||
@test d(xs) ≈ dt(xs)
|
|
||||||
end
|
|
||||||
|
|
||||||
@mxonly let dm = mxnet(d, (1, 20))
|
@mxonly let dm = mxnet(d, (1, 20))
|
||||||
@test d(xs) ≈ dm(xs)
|
@test d(xs) ≈ dm(xs)
|
||||||
end
|
end
|
||||||
|
|
||||||
# TensorFlow native integration
|
@mxonly let
|
||||||
|
using MXNet
|
||||||
|
f = mx.FeedForward(Chain(d, softmax))
|
||||||
|
@test isa(f, mx.FeedForward)
|
||||||
|
# TODO: test run
|
||||||
|
end
|
||||||
|
|
||||||
|
# TensorFlow
|
||||||
|
|
||||||
|
@tfonly let dt = tf(d)
|
||||||
|
@test d(xs) ≈ dt(xs)
|
||||||
|
end
|
||||||
|
|
||||||
@tfonly let
|
@tfonly let
|
||||||
using TensorFlow
|
using TensorFlow
|
||||||
|
Loading…
Reference in New Issue
Block a user