From 7d2a34b55d543bb2d83542228534fdbffe697aa7 Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Sat, 17 Jun 2017 19:21:39 -0700 Subject: [PATCH] unroll multiple inputs --- src/compiler/loops.jl | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/src/compiler/loops.jl b/src/compiler/loops.jl index 0599a5db..72b882f3 100644 --- a/src/compiler/loops.jl +++ b/src/compiler/loops.jl @@ -9,9 +9,9 @@ end Stateful(model, ss) = Stateful(model, ss, state.(ss), state.(ss)) -function (m::Stateful)(x) +function (m::Stateful)(xs...) m.istate = m.ostate - state, y = m.model((m.istate...,), x) + state, y = m.model((m.istate...,), xs...) m.ostate = collect(state) return y end @@ -118,13 +118,14 @@ function collect_state(v::IVertex) return state, offset, default end -hiddeninput(n) = vertex(Split(n), inputnode(1)) +hiddeninput(n, t) = vertex(Split(t), inputnode(n)) -create_steps(v::IVertex, n) = [bumpinputs(spliceinputs(v, hiddeninput(i))) for i = 1:n] +# TODO: nicer way to do this. +create_steps(v::IVertex, n) = [bumpinputs(spliceinputs(v, [hiddeninput(n, t) for n = 1:graphinputs(v)]...)) for t = 1:n] function getvar(n, step, steps, offset, default) if step < 1 - hiddeninput(sum(offset[1:n-1]) + 1 - step) + hiddeninput(1, sum(offset[1:n-1]) + 1 - step) elseif step ∉ 1:length(steps) constant(default[n]) else @@ -182,7 +183,7 @@ stateless(s::SeqModel) = SeqModel(stateless(s.model), s.steps) function unseqin(v::IVertex) prewalk(v) do v # TODO: inputidx function - isa(value(v), Split) && DataFlow.isinput(v[1]) && value(v[1]).n == 2 ? v[1] : v + isa(value(v), Split) && DataFlow.isinput(v[1]) && value(v[1]).n > 1 ? v[1] : v end end