mxnet recurrence test
This commit is contained in:
parent
b4221f6ea6
commit
f8e1f20728
@ -133,7 +133,7 @@ Flux.update!(m::Model, η) = (update!(m.last, η); m)
|
||||
|
||||
using Flux: Stateful, SeqModel
|
||||
|
||||
mxnet(m::Stateful) = Stateful(mxnet(m.model), m.state)
|
||||
mxnet(m::Stateful) = Stateful(mxnet(m.model), copy(m.state))
|
||||
mxnet(m::SeqModel) = SeqModel(mxnet(m.model), m.steps)
|
||||
|
||||
# MX FeedForward interface
|
||||
|
@ -13,6 +13,13 @@ m = Multi(20, 15)
|
||||
mm = mxnet(m)
|
||||
@test all(isapprox.(mm(xs, ys), m(xs, ys)))
|
||||
|
||||
@testset "Recurrence" begin
|
||||
seq = Seq(rand(10) for i = 1:3)
|
||||
r = unroll(Recurrent(10, 5), 3)
|
||||
rm = mxnet(r)
|
||||
@test r(seq) ≈ rm(seq)
|
||||
end
|
||||
|
||||
@testset "Backward Pass" begin
|
||||
d′ = deepcopy(d)
|
||||
@test dm(xs) ≈ d(xs)
|
||||
@ -26,7 +33,7 @@ mm = mxnet(m)
|
||||
@test dm(xs) ≉ d′(xs)
|
||||
end
|
||||
|
||||
@testset "FeedForward interface" begin
|
||||
@testset "Native interface" begin
|
||||
f = mx.FeedForward(Chain(d, softmax))
|
||||
@test mx.infer_shape(f.arch, data = (20, 1))[2] == [(10, 1)]
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user