mxnet integration test
This commit is contained in:
parent
ad4d60f90d
commit
a14cbf301d
|
@ -1,15 +1,24 @@
|
|||
xs = rand(20)
|
||||
d = Affine(20, 10)
|
||||
|
||||
@tfonly let dt = tf(d)
|
||||
@test d(xs) ≈ dt(xs)
|
||||
end
|
||||
# MXNet
|
||||
|
||||
@mxonly let dm = mxnet(d, (1, 20))
|
||||
@test d(xs) ≈ dm(xs)
|
||||
end
|
||||
|
||||
# TensorFlow native integration
|
||||
@mxonly let
|
||||
using MXNet
|
||||
f = mx.FeedForward(Chain(d, softmax))
|
||||
@test isa(f, mx.FeedForward)
|
||||
# TODO: test run
|
||||
end
|
||||
|
||||
# TensorFlow
|
||||
|
||||
@tfonly let dt = tf(d)
|
||||
@test d(xs) ≈ dt(xs)
|
||||
end
|
||||
|
||||
@tfonly let
|
||||
using TensorFlow
|
||||
|
|
Loading…
Reference in New Issue