diff --git a/src/backend/mxnet/model.jl b/src/backend/mxnet/model.jl index 9d7a310e..304ba170 100644 --- a/src/backend/mxnet/model.jl +++ b/src/backend/mxnet/model.jl @@ -133,7 +133,7 @@ Flux.update!(m::Model, η) = (update!(m.last, η); m) using Flux: Stateful, SeqModel -mxnet(m::Stateful) = Stateful(mxnet(m.model), m.state) +mxnet(m::Stateful) = Stateful(mxnet(m.model), copy(m.state)) mxnet(m::SeqModel) = SeqModel(mxnet(m.model), m.steps) # MX FeedForward interface diff --git a/test/backend/mxnet.jl b/test/backend/mxnet.jl index e79c840a..e27bbea6 100644 --- a/test/backend/mxnet.jl +++ b/test/backend/mxnet.jl @@ -13,6 +13,13 @@ m = Multi(20, 15) mm = mxnet(m) @test all(isapprox.(mm(xs, ys), m(xs, ys))) +@testset "Recurrence" begin + seq = Seq(rand(10) for i = 1:3) + r = unroll(Recurrent(10, 5), 3) + rm = mxnet(r) + @test r(seq) ≈ rm(seq) +end + @testset "Backward Pass" begin d′ = deepcopy(d) @test dm(xs) ≈ d(xs) @@ -26,7 +33,7 @@ mm = mxnet(m) @test dm(xs) ≉ d′(xs) end -@testset "FeedForward interface" begin +@testset "Native interface" begin f = mx.FeedForward(Chain(d, softmax)) @test mx.infer_shape(f.arch, data = (20, 1))[2] == [(10, 1)]