unroll multiple inputs

This commit is contained in:
Mike J Innes 2017-06-17 19:21:39 -07:00
parent 8306ed2ed7
commit 7d2a34b55d

View File

@ -9,9 +9,9 @@ end
Stateful(model, ss) = Stateful(model, ss, state.(ss), state.(ss))
function (m::Stateful)(x)
function (m::Stateful)(xs...)
m.istate = m.ostate
state, y = m.model((m.istate...,), x)
state, y = m.model((m.istate...,), xs...)
m.ostate = collect(state)
return y
end
@ -118,13 +118,14 @@ function collect_state(v::IVertex)
return state, offset, default
end
hiddeninput(n) = vertex(Split(n), inputnode(1))
hiddeninput(n, t) = vertex(Split(t), inputnode(n))
create_steps(v::IVertex, n) = [bumpinputs(spliceinputs(v, hiddeninput(i))) for i = 1:n]
# TODO: nicer way to do this.
create_steps(v::IVertex, n) = [bumpinputs(spliceinputs(v, [hiddeninput(n, t) for n = 1:graphinputs(v)]...)) for t = 1:n]
function getvar(n, step, steps, offset, default)
if step < 1
hiddeninput(sum(offset[1:n-1]) + 1 - step)
hiddeninput(1, sum(offset[1:n-1]) + 1 - step)
elseif step 1:length(steps)
constant(default[n])
else
@ -182,7 +183,7 @@ stateless(s::SeqModel) = SeqModel(stateless(s.model), s.steps)
function unseqin(v::IVertex)
prewalk(v) do v
# TODO: inputidx function
isa(value(v), Split) && DataFlow.isinput(v[1]) && value(v[1]).n == 2 ? v[1] : v
isa(value(v), Split) && DataFlow.isinput(v[1]) && value(v[1]).n > 1 ? v[1] : v
end
end