preserve original param states
This commit is contained in:
parent
af65b9200c
commit
2b85c76785
@ -100,7 +100,7 @@ function unrollgraph(v::IVertex, n)
|
|||||||
end
|
end
|
||||||
out = group(map(x->x[2], steps)...)
|
out = group(map(x->x[2], steps)...)
|
||||||
state, defaults = stateout(steps, offset, default)
|
state, defaults = stateout(steps, offset, default)
|
||||||
group(state,out), map(Flux.state, defaults)
|
group(state,out), defaults
|
||||||
end
|
end
|
||||||
|
|
||||||
unrollgraph(m, n; kws...) = unrollgraph(atomise(m), n; kws...)
|
unrollgraph(m, n; kws...) = unrollgraph(atomise(m), n; kws...)
|
||||||
|
@ -42,11 +42,12 @@ end
|
|||||||
|
|
||||||
mutable struct Stateful <: Model
|
mutable struct Stateful <: Model
|
||||||
model
|
model
|
||||||
|
states::Vector{Any}
|
||||||
istate::Vector{Any}
|
istate::Vector{Any}
|
||||||
ostate::Vector{Any}
|
ostate::Vector{Any}
|
||||||
end
|
end
|
||||||
|
|
||||||
Stateful(model, state) = Stateful(model, state, state)
|
Stateful(model, ss) = Stateful(model, ss, state.(ss), state.(ss))
|
||||||
|
|
||||||
function (m::Stateful)(x)
|
function (m::Stateful)(x)
|
||||||
m.istate = m.ostate
|
m.istate = m.ostate
|
||||||
|
Loading…
Reference in New Issue
Block a user