diff --git a/src/backend/mxnet/model.jl b/src/backend/mxnet/model.jl index 11c6db60..2bb33202 100644 --- a/src/backend/mxnet/model.jl +++ b/src/backend/mxnet/model.jl @@ -133,7 +133,7 @@ Flux.update!(m::Model, η) = (update!(m.last, η); m) using Flux: Stateful, SeqModel -mxnet(m::Stateful) = Stateful(mxnet(m.model), copy(m.state)) +mxnet(m::Stateful) = Stateful(mxnet(m.model), m.istate, m.ostate) mxnet(m::SeqModel) = SeqModel(mxnet(m.model), m.steps) # MX FeedForward interface diff --git a/src/model.jl b/src/model.jl index 61b4fdc0..371e7e2f 100644 --- a/src/model.jl +++ b/src/model.jl @@ -127,17 +127,25 @@ graph(cap::Capacitor) = cap.graph # Recurrent Models -struct Stateful <: Model +mutable struct Stateful <: Model model - state::Vector{Any} + istate::Vector{Any} + ostate::Vector{Any} end +Stateful(model, state) = Stateful(model, state, state) + function (m::Stateful)(x) - state, y = runmodel(m.model, (m.state...,), x) - m.state .= state + m.istate = m.ostate + state, y = runmodel(m.model, (m.istate...,), x) + m.ostate = collect(state) return y end +function back!(m::Stateful, Δ, x) + back!(m.model, ((zeros.(m.ostate)...,), Δ), (m.istate...,), x)[2:end] +end + stateless(m) = m stateless(m::Stateful) = m.model @@ -157,5 +165,4 @@ function (m::SeqModel)(x) end end -(m::SeqModel)(x::AbstractArray) = stack(m((unstack(x, 2)...,)), 2) -(m::SeqModel)(x::BatchSeq) = rebatchseq(m(rawbatch(x))) +back!(m::SeqModel, Δ, x) = (runseq((Δ, x) -> back!(m.model, Δ, x)[1], Δ, x),)