unroll multiple inputs

This commit is contained in:
Mike J Innes 2017-06-17 19:21:39 -07:00
parent 8306ed2ed7
commit 7d2a34b55d

View File

@ -9,9 +9,9 @@ end
Stateful(model, ss) = Stateful(model, ss, state.(ss), state.(ss)) Stateful(model, ss) = Stateful(model, ss, state.(ss), state.(ss))
function (m::Stateful)(x) function (m::Stateful)(xs...)
m.istate = m.ostate m.istate = m.ostate
state, y = m.model((m.istate...,), x) state, y = m.model((m.istate...,), xs...)
m.ostate = collect(state) m.ostate = collect(state)
return y return y
end end
@ -118,13 +118,14 @@ function collect_state(v::IVertex)
return state, offset, default return state, offset, default
end end
hiddeninput(n) = vertex(Split(n), inputnode(1)) hiddeninput(n, t) = vertex(Split(t), inputnode(n))
create_steps(v::IVertex, n) = [bumpinputs(spliceinputs(v, hiddeninput(i))) for i = 1:n] # TODO: nicer way to do this.
create_steps(v::IVertex, n) = [bumpinputs(spliceinputs(v, [hiddeninput(n, t) for n = 1:graphinputs(v)]...)) for t = 1:n]
function getvar(n, step, steps, offset, default) function getvar(n, step, steps, offset, default)
if step < 1 if step < 1
hiddeninput(sum(offset[1:n-1]) + 1 - step) hiddeninput(1, sum(offset[1:n-1]) + 1 - step)
elseif step 1:length(steps) elseif step 1:length(steps)
constant(default[n]) constant(default[n])
else else
@ -182,7 +183,7 @@ stateless(s::SeqModel) = SeqModel(stateless(s.model), s.steps)
function unseqin(v::IVertex) function unseqin(v::IVertex)
prewalk(v) do v prewalk(v) do v
# TODO: inputidx function # TODO: inputidx function
isa(value(v), Split) && DataFlow.isinput(v[1]) && value(v[1]).n == 2 ? v[1] : v isa(value(v), Split) && DataFlow.isinput(v[1]) && value(v[1]).n > 1 ? v[1] : v
end end
end end