preserve original param states

This commit is contained in:
Mike J Innes 2017-06-02 15:46:24 +01:00
parent af65b9200c
commit 2b85c76785
2 changed files with 3 additions and 2 deletions

View File

@ -100,7 +100,7 @@ function unrollgraph(v::IVertex, n)
end end
out = group(map(x->x[2], steps)...) out = group(map(x->x[2], steps)...)
state, defaults = stateout(steps, offset, default) state, defaults = stateout(steps, offset, default)
group(state,out), map(Flux.state, defaults) group(state,out), defaults
end end
unrollgraph(m, n; kws...) = unrollgraph(atomise(m), n; kws...) unrollgraph(m, n; kws...) = unrollgraph(atomise(m), n; kws...)

View File

@ -42,11 +42,12 @@ end
mutable struct Stateful <: Model mutable struct Stateful <: Model
model model
states::Vector{Any}
istate::Vector{Any} istate::Vector{Any}
ostate::Vector{Any} ostate::Vector{Any}
end end
Stateful(model, state) = Stateful(model, state, state) Stateful(model, ss) = Stateful(model, ss, state.(ss), state.(ss))
function (m::Stateful)(x) function (m::Stateful)(x)
m.istate = m.ostate m.istate = m.ostate