From 8f911cc31ebc8bb9a420ec1b7dc3d4690e98efaf Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Mon, 27 Feb 2017 22:52:08 +0000 Subject: [PATCH] so can unseq --- src/compiler/loops.jl | 24 ++++++++++++++---------- 1 file changed, 14 insertions(+), 10 deletions(-) diff --git a/src/compiler/loops.jl b/src/compiler/loops.jl index 3bb4ff56..0ee1c865 100644 --- a/src/compiler/loops.jl +++ b/src/compiler/loops.jl @@ -60,9 +60,7 @@ end hiddeninput(n) = vertex(Split(n), inputnode(1)) -function create_steps(v::IVertex, n; seq = true) - [copy(seq ? spliceinputs(v, hiddeninput(i)) : v) for i = 1:n] -end +create_steps(v::IVertex, n) = [bumpinputs(spliceinputs(v, hiddeninput(i))) for i = 1:n] function getvar(n, step, steps, offset, default) if step < 1 @@ -86,13 +84,11 @@ end # Input: (hidden1, hidden2, ...), (x1, x2, ...) # Output: (hidden1, hidden2, ...), (y1, y2, ...) -# If `seq` is false, takes a single `x` and uses this for each iteration. -# If `stateful` is false there are no hidden inputs or outputs. -function unrollgraph(v::IVertex, n; seq = true) +function unrollgraph(v::IVertex, n) state, offset, default = collect_state(v) v = group(group(state...), v) - steps = create_steps(v, n, seq = seq) + steps = create_steps(v, n) for i = 1:n vars = inputs(steps[i][1]) postwalk!(steps[i]) do v @@ -121,13 +117,21 @@ end graph(u::Unrolled) = u.graph -function unroll(model, n; seq = true) - graph, state = unrollgraph(model, n; seq = seq) +function unroll(model, n) + graph, state = unrollgraph(model, n) Unrolled(model, graph, state, true, n) end +function unseqinput(v::IVertex) + prewalk(v) do v + # TODO: inputidx function + isa(value(v), Split) && DataFlow.isinput(v[1]) && value(v[1]).n == 2 ? v[1] : v + end +end + function unroll1(model) - graph, state = unrollgraph(model, 1; seq = false) + graph, state = unrollgraph(model, 1) + graph = unseqinput(graph) graph = group(graph[1], map(x->x[1], inputs(graph)[2:end])...) Unrolled(model, graph, state, false, 1) end