# Flux.jl/test/backend/mxnet.jl — tests for Flux's MXNet backend.
using MXNet
Flux.loadmx()

# Tests for the MXNet backend: each case builds a Flux model, converts it with
# `mxnet(...)`, and checks that the converted model agrees with the native one.
@testset "MXNet" begin

xs = rand(1, 20)
d = Affine(20, 10)

dm = mxnet(d)
@test d(xs) ≈ dm(xs)

@testset "Tuple I/O" begin
  @test mxnet(@net x -> (x,))([1,2,3]) == ([1,2,3],)
  @test mxnet(@net x -> x[1].*x[2])(([1,2,3],[4,5,6])) == [4,10,18]
end

@testset "Recurrence" begin
  seq = batchone(Seq(rand(10) for i = 1:3))
  r = unroll(Recurrent(10, 5), 3)
  rmx = mxnet(r)  # named `rmx` rather than `rm` to avoid shadowing `Base.rm`
  @test r(seq) ≈ rmx(seq)
end

@testset "Backward Pass" begin
  # Snapshot of `d` taken before training; it must be a separate binding so the
  # post-update assertions can tell the old parameters from the new ones.
  d′ = deepcopy(d)
  @test dm(xs) ≈ d(xs)
  @test dm(xs) ≈ d′(xs)

  Δ = back!(dm, randn(1, 10), xs)
  @test length(Δ[1]) == 20
  update!(dm, 0.1)

  # The assertions below expect the update applied through `dm` to be visible
  # on `d` as well, while the pre-update snapshot `d′` no longer matches.
  @test dm(xs) ≈ d(xs)
  @test !(dm(xs) ≈ d′(xs))
end

@testset "Native interface" begin
  f = mx.FeedForward(Chain(d, softmax))
  @test mx.infer_shape(f.arch, data = (20, 1))[2] == [(10, 1)]

  m = Chain(Input(28,28), Conv2D((5,5), out = 3), MaxPool((2,2)),
            flatten, Affine(1587, 10), softmax)
  f = mx.FeedForward(m)
  # TODO: test run
  @test mx.infer_shape(f.arch, data = (20, 20, 5, 1))[2] == [(10, 1)]
end

@testset "Stack Traces" begin
  # Deliberately mismatched layers: the first Affine produces 20 outputs but the
  # second expects 21 inputs, so calling the model is expected to throw.
  model = TLP(Affine(10, 20), Affine(21, 15))
  info("The following warning is normal")
  dm = mxnet(model)
  e = try
    dm(rand(1, 10))
  catch err
    err
  end

  @test isa(e, DataFlow.Interpreter.Exception)
  @test e.trace[1].func == Symbol("Flux.Affine")
  @test e.trace[2].func == :TLP
end

end