unroll multiple inputs
This commit is contained in:
parent
8306ed2ed7
commit
7d2a34b55d
@ -9,9 +9,9 @@ end
|
||||
|
||||
Stateful(model, ss) = Stateful(model, ss, state.(ss), state.(ss))
|
||||
|
||||
function (m::Stateful)(x)
|
||||
function (m::Stateful)(xs...)
|
||||
m.istate = m.ostate
|
||||
state, y = m.model((m.istate...,), x)
|
||||
state, y = m.model((m.istate...,), xs...)
|
||||
m.ostate = collect(state)
|
||||
return y
|
||||
end
|
||||
@ -118,13 +118,14 @@ function collect_state(v::IVertex)
|
||||
return state, offset, default
|
||||
end
|
||||
|
||||
hiddeninput(n) = vertex(Split(n), inputnode(1))
|
||||
hiddeninput(n, t) = vertex(Split(t), inputnode(n))
|
||||
|
||||
create_steps(v::IVertex, n) = [bumpinputs(spliceinputs(v, hiddeninput(i))) for i = 1:n]
|
||||
# TODO: nicer way to do this.
|
||||
create_steps(v::IVertex, n) = [bumpinputs(spliceinputs(v, [hiddeninput(n, t) for n = 1:graphinputs(v)]...)) for t = 1:n]
|
||||
|
||||
function getvar(n, step, steps, offset, default)
|
||||
if step < 1
|
||||
hiddeninput(sum(offset[1:n-1]) + 1 - step)
|
||||
hiddeninput(1, sum(offset[1:n-1]) + 1 - step)
|
||||
elseif step ∉ 1:length(steps)
|
||||
constant(default[n])
|
||||
else
|
||||
@ -182,7 +183,7 @@ stateless(s::SeqModel) = SeqModel(stateless(s.model), s.steps)
|
||||
function unseqin(v::IVertex)
|
||||
prewalk(v) do v
|
||||
# TODO: inputidx function
|
||||
isa(value(v), Split) && DataFlow.isinput(v[1]) && value(v[1]).n == 2 ? v[1] : v
|
||||
isa(value(v), Split) && DataFlow.isinput(v[1]) && value(v[1]).n > 1 ? v[1] : v
|
||||
end
|
||||
end
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user