diff --git a/src/Flux.jl b/src/Flux.jl index c72b3171..671581aa 100644 --- a/src/Flux.jl +++ b/src/Flux.jl @@ -6,7 +6,7 @@ using MacroTools, Lazy, DataFlow, Juno using DataFlow: graphm, syntax, prewalk!, postwalk!, prewalk, postwalk, iscyclic, Constant, constant, isconstant, group, Split, splitnode, detuple, value, inputs, thread!, value, inputs, Split, splitnode, inputnode, - spliceinputs, bumpinputs, Line, Frame, applylines + spliceinputs, bumpinputs, Line, Frame, applylines, graphinputs using DataFlow.Interpreter using Juno: Tree, Row diff --git a/src/compiler/loops.jl b/src/compiler/loops.jl index 6b38bf13..52d54806 100644 --- a/src/compiler/loops.jl +++ b/src/compiler/loops.jl @@ -136,6 +136,15 @@ function unroll(model, n) SeqModel(Stateful(Capacitor(graph), state), n) end +function stateless(s::Stateful) + v = graph(s.model) + v = spliceinputs(v, group(constant.(s.states)...), + [inputnode(i) for i = 1:graphinputs(v)-1]...) + Capacitor(v) +end + +stateless(s::SeqModel) = SeqModel(stateless(s.model), s.steps) + function unseqin(v::IVertex) prewalk(v) do v # TODO: inputidx function