diff --git a/src/backend/mxnet/model.jl b/src/backend/mxnet/model.jl index 19fe53a1..6a09d44f 100644 --- a/src/backend/mxnet/model.jl +++ b/src/backend/mxnet/model.jl @@ -124,8 +124,8 @@ Flux.update!(m::Model, η) = (update!(m.last, η); m) using Flux: Stateful, SeqModel -mxnet(m::Stateful) = Stateful(mxnet(m.model), m.states, m.istate, m.ostate) -mxnet(m::SeqModel) = SeqModel(mxnet(m.model), m.steps) +mxnet(m::Stateful, a...) = Stateful(mxnet(m.model, a...), m.states, m.istate, m.ostate) +mxnet(m::SeqModel, a...) = SeqModel(mxnet(m.model, a...), m.steps) # MX FeedForward interface