diff --git a/src/compiler/graph.jl b/src/compiler/graph.jl index 5ee0add9..3a9b02f9 100644 --- a/src/compiler/graph.jl +++ b/src/compiler/graph.jl @@ -4,9 +4,11 @@ splitnode(v, n) = vertex(Split(n), v) inputnode(n) = splitnode(constant(ModelInput()), n) +isinput(v::IVertex) = isa(value(v), Split) && value(v[1]) == Constant(ModelInput()) + function bumpinputs(v::IVertex) prewalk(v) do v - isa(value(v), Split) && value(v[1]) == Constant(ModelInput()) ? + isinput(v) ? inputnode(value(v).n + 1) : v end @@ -21,6 +23,15 @@ end spliceinputs(v::IVertex, inputs::Vertex...) = spliceinput(v, vertex(Group(), inputs...)) +function ninputs(v::IVertex) + n = 0 + prewalk(v) do v + isinput(v) && (n = max(n, value(v).n)) + v + end + return n +end + function detuple(v::IVertex) postwalk(v) do v if isa(value(v), Split) && isa(value(v[1]), Group)