From 69b24bfa9b11a67121c38215cc4ded48bb5ed8ce Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Mon, 27 Feb 2017 22:06:38 +0000 Subject: [PATCH] stateless can be a postprocess --- src/compiler/loops.jl | 28 ++++++++++++---------------- 1 file changed, 12 insertions(+), 16 deletions(-) diff --git a/src/compiler/loops.jl b/src/compiler/loops.jl index 4553b7f5..3bb4ff56 100644 --- a/src/compiler/loops.jl +++ b/src/compiler/loops.jl @@ -60,12 +60,12 @@ end hiddeninput(n) = vertex(Split(n), inputnode(1)) -function create_steps(v::IVertex, n; seq = true, stateful = true) - [(stateful ? bumpinputs : copy)(seq ? spliceinputs(v, hiddeninput(i)) : v) for i = 1:n] +function create_steps(v::IVertex, n; seq = true) + [copy(seq ? spliceinputs(v, hiddeninput(i)) : v) for i = 1:n] end -function getvar(n, step, steps, offset, default; stateful = true) - if stateful && step < 1 +function getvar(n, step, steps, offset, default) + if step < 1 hiddeninput(sum(offset[1:n-1]) + 1 - step) elseif step ∉ 1:length(steps) constant(default[n]) @@ -84,30 +84,26 @@ function stateout(steps, offset, default) group(outs...), defaults end -function unrollgraph(v::IVertex, n; seq = true, stateful = true) # Input: (hidden1, hidden2, ...), (x1, x2, ...) # Output: (hidden1, hidden2, ...), (y1, y2, ...) # If `seq` is false, takes a single `x` and uses this for each iteration. # If `stateful` is false there are no hidden inputs or outputs. +function unrollgraph(v::IVertex, n; seq = true) state, offset, default = collect_state(v) v = group(group(state...), v) - steps = create_steps(v, n, seq = seq, stateful = stateful) + steps = create_steps(v, n, seq = seq) for i = 1:n vars = inputs(steps[i][1]) postwalk!(steps[i]) do v value(v) isa Offset || return v varid = findfirst(vars,v[1]) - getvar(varid, value(v).n + i, steps, offset, default, stateful = stateful) + getvar(varid, value(v).n + i, steps, offset, default) end end out = group(map(x->x[2], steps)...) - if stateful - state, defaults = stateout(steps, offset, default) - group(state,out), map(Flux.state, defaults) - else - out, [] - end + state, defaults = stateout(steps, offset, default) + group(state,out), map(Flux.state, defaults) end unrollgraph(m, n; kws...) = unrollgraph(atomise(m), n; kws...) @@ -125,9 +121,9 @@ end graph(u::Unrolled) = u.graph -function unroll(model, n; seq = true, stateful = true) - graph, state = unrollgraph(model, n; seq = seq, stateful = stateful) - seq || stateful ? Unrolled(model, graph, state, stateful, n) : Capacitor(graph) +function unroll(model, n; seq = true) + graph, state = unrollgraph(model, n; seq = seq) + Unrolled(model, graph, state, true, n) end function unroll1(model)