ctx methods for seq models
This commit is contained in:
parent
020ae616cc
commit
1cc8100456
@ -124,8 +124,8 @@ Flux.update!(m::Model, η) = (update!(m.last, η); m)
|
||||
|
||||
using Flux: Stateful, SeqModel
|
||||
|
||||
mxnet(m::Stateful) = Stateful(mxnet(m.model), m.states, m.istate, m.ostate)
|
||||
mxnet(m::SeqModel) = SeqModel(mxnet(m.model), m.steps)
|
||||
mxnet(m::Stateful, a...) = Stateful(mxnet(m.model, a...), m.states, m.istate, m.ostate)
|
||||
mxnet(m::SeqModel, a...) = SeqModel(mxnet(m.model, a...), m.steps)
|
||||
|
||||
# MX FeedForward interface
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user