diff --git a/src/compiler/loops.jl b/src/compiler/loops.jl index a48c541c..0f3bd6a1 100644 --- a/src/compiler/loops.jl +++ b/src/compiler/loops.jl @@ -20,9 +20,22 @@ bumpinput(x) = x bumpinputs(v::IVertex) = mapconst(bumpinput, v) +function unroll(delay::IVertex) + prewalk(delay[1]) do v + isa(value(v), Delay) ? constant(ModelInput(1)) : v + end +end + function break!(model) iscyclic(graph(model)) || return model - bumpinputs(graph(model)) + g = bumpinputs(graph(model)) + loops = [] + g = prewalk(g) do v + isa(value(v), Delay) || return v + push!(loops, unroll(v)) + constant(ModelInput(1)) + end + cse(vertex(tuple, loops..., g)) end # r = Recurrent(784, 10, 50)