preserve original param states
This commit is contained in:
parent
af65b9200c
commit
2b85c76785
@ -100,7 +100,7 @@ function unrollgraph(v::IVertex, n)
|
||||
end
|
||||
out = group(map(x->x[2], steps)...)
|
||||
state, defaults = stateout(steps, offset, default)
|
||||
group(state,out), map(Flux.state, defaults)
|
||||
group(state,out), defaults
|
||||
end
|
||||
|
||||
unrollgraph(m, n; kws...) = unrollgraph(atomise(m), n; kws...)
|
||||
|
@ -42,11 +42,12 @@ end
|
||||
|
||||
mutable struct Stateful <: Model
|
||||
model
|
||||
states::Vector{Any}
|
||||
istate::Vector{Any}
|
||||
ostate::Vector{Any}
|
||||
end
|
||||
|
||||
Stateful(model, state) = Stateful(model, state, state)
|
||||
Stateful(model, ss) = Stateful(model, ss, state.(ss), state.(ss))
|
||||
|
||||
function (m::Stateful)(x)
|
||||
m.istate = m.ostate
|
||||
|
Loading…
Reference in New Issue
Block a user