diff --git a/src/compiler/loops.jl b/src/compiler/loops.jl index eba9378c..3785d1a2 100644 --- a/src/compiler/loops.jl +++ b/src/compiler/loops.jl @@ -100,7 +100,7 @@ function unrollgraph(v::IVertex, n) end out = group(map(x->x[2], steps)...) state, defaults = stateout(steps, offset, default) - group(state,out), map(Flux.state, defaults) + group(state,out), defaults end unrollgraph(m, n; kws...) = unrollgraph(atomise(m), n; kws...) diff --git a/src/layers/control.jl b/src/layers/control.jl index f30af30d..7cc907f7 100644 --- a/src/layers/control.jl +++ b/src/layers/control.jl @@ -42,11 +42,12 @@ end mutable struct Stateful <: Model model + states::Vector{Any} istate::Vector{Any} ostate::Vector{Any} end -Stateful(model, state) = Stateful(model, state, state) +Stateful(model, ss) = Stateful(model, ss, state.(ss), state.(ss)) function (m::Stateful)(x) m.istate = m.ostate