mxnet recurrence test

This commit is contained in:
Mike J Innes 2017-03-31 12:39:23 +01:00
parent b4221f6ea6
commit f8e1f20728
2 changed files with 9 additions and 2 deletions

View File

@ -133,7 +133,7 @@ Flux.update!(m::Model, η) = (update!(m.last, η); m)
using Flux: Stateful, SeqModel
mxnet(m::Stateful) = Stateful(mxnet(m.model), m.state)
mxnet(m::Stateful) = Stateful(mxnet(m.model), copy(m.state))
mxnet(m::SeqModel) = SeqModel(mxnet(m.model), m.steps)
# MX FeedForward interface

View File

@ -13,6 +13,13 @@ m = Multi(20, 15)
mm = mxnet(m)
@test all(isapprox.(mm(xs, ys), m(xs, ys)))
@testset "Recurrence" begin
seq = Seq(rand(10) for i = 1:3)
r = unroll(Recurrent(10, 5), 3)
rm = mxnet(r)
@test r(seq) rm(seq)
end
@testset "Backward Pass" begin
d = deepcopy(d)
@test dm(xs) d(xs)
@ -26,7 +33,7 @@ mm = mxnet(m)
@test dm(xs) d(xs)
end
@testset "FeedForward interface" begin
@testset "Native interface" begin
f = mx.FeedForward(Chain(d, softmax))
@test mx.infer_shape(f.arch, data = (20, 1))[2] == [(10, 1)]