Flux.jl/test/backend/mxnet.jl
2017-03-31 12:39:23 +01:00

60 lines
1.3 KiB
Julia
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

using MXNet
Flux.loadmx()

# Tests for the MXNet backend: models converted with `mxnet(...)` are compared
# against the native Flux models they were built from.
@testset "MXNet" begin

  xs, ys = rand(20), rand(20)
  d = Affine(20, 10)

  dm = mxnet(d)
  @test d(xs) ≈ dm(xs)

  # `Multi` (a helper defined outside this file) is called with two inputs;
  # its outputs are compared pairwise with a broadcast `isapprox`.
  m = Multi(20, 15)
  mm = mxnet(m)
  @test all(isapprox.(mm(xs, ys), m(xs, ys)))

  @testset "Recurrence" begin
    seq = Seq(rand(10) for i = 1:3)
    r = unroll(Recurrent(10, 5), 3)
    # Named `rmx` rather than `rm` to avoid shadowing `Base.rm`.
    rmx = mxnet(r)
    @test r(seq) ≈ rmx(seq)
  end

  @testset "Backward Pass" begin
    # Snapshot the native model under a NEW name. Writing `d = deepcopy(d)`
    # here would rebind the enclosing `d` (testset bodies are soft local
    # scopes), so the post-update checks below would compare against the copy.
    dcopy = deepcopy(d)
    @test dm(xs) ≈ d(xs)
    @test dm(xs) ≈ dcopy(xs)

    # Back-propagate a random gradient for the 10 outputs; the first element
    # of Δ is expected to have the input's length (20) — presumably the
    # gradient w.r.t. `xs`.
    Δ = back!(dm, randn(10), xs)
    @test length(Δ[1]) == 20
    update!(dm, 0.1)

    # After the update `dm` is still expected to agree with `d` (the test
    # assumes they share parameters — NOTE(review): confirm) but to have moved
    # away from the pre-update snapshot.
    @test dm(xs) ≈ d(xs)
    @test !(dm(xs) ≈ dcopy(xs))
  end

  @testset "Native interface" begin
    f = mx.FeedForward(Chain(d, softmax))
    @test mx.infer_shape(f.arch, data = (20, 1))[2] == [(10, 1)]

    m = Chain(Input(28,28), Conv2D((5,5), out = 3), MaxPool((2,2)),
              flatten, Affine(1587, 10), softmax)
    f = mx.FeedForward(m)
    # Only the inferred output shape is asserted; the network is not executed.
    # TODO: test run
    @test mx.infer_shape(f.arch, data = (20, 20, 5, 1))[2] == [(10, 1)]
  end

  @testset "Stack Traces" begin
    # Affine(10, 20) produces 20 values but Affine(21, 15) is declared with 21
    # inputs, so calling the converted model is expected to throw.
    model = TLP(Affine(10, 20), Affine(21, 15))
    info("The following warning is normal")
    # Separate name so the outer `dm` is not rebound.
    tlpm = mxnet(model)
    e = try
      tlpm(rand(10))
    catch err
      err
    end
    @test isa(e, DataFlow.Interpreter.Exception)
    # The trace should name the failing layer first, then the enclosing model.
    @test e.trace[1].func == Symbol("Flux.Affine")
    @test e.trace[2].func == :TLP
  end

end