ctx methods for seq models
This commit is contained in:
parent
020ae616cc
commit
1cc8100456
@ -124,8 +124,8 @@ Flux.update!(m::Model, η) = (update!(m.last, η); m)
|
|||||||
|
|
||||||
using Flux: Stateful, SeqModel
|
using Flux: Stateful, SeqModel
|
||||||
|
|
||||||
mxnet(m::Stateful) = Stateful(mxnet(m.model), m.states, m.istate, m.ostate)
|
mxnet(m::Stateful, a...) = Stateful(mxnet(m.model, a...), m.states, m.istate, m.ostate)
|
||||||
mxnet(m::SeqModel) = SeqModel(mxnet(m.model), m.steps)
|
mxnet(m::SeqModel, a...) = SeqModel(mxnet(m.model, a...), m.steps)
|
||||||
|
|
||||||
# MX FeedForward interface
|
# MX FeedForward interface
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user