basic unrolling
This commit is contained in:
parent
19b5e8bd21
commit
cd968af228
|
@ -20,9 +20,22 @@ bumpinput(x) = x
|
|||
|
||||
bumpinputs(v::IVertex) = mapconst(bumpinput, v)
|
||||
|
||||
function unroll(delay::IVertex)
|
||||
prewalk(delay[1]) do v
|
||||
isa(value(v), Delay) ? constant(ModelInput(1)) : v
|
||||
end
|
||||
end
|
||||
|
||||
function break!(model)
|
||||
iscyclic(graph(model)) || return model
|
||||
bumpinputs(graph(model))
|
||||
g = bumpinputs(graph(model))
|
||||
loops = []
|
||||
g = prewalk(g) do v
|
||||
isa(value(v), Delay) || return v
|
||||
push!(loops, unroll(v))
|
||||
constant(ModelInput(1))
|
||||
end
|
||||
cse(vertex(tuple, loops..., g))
|
||||
end
|
||||
|
||||
# r = Recurrent(784, 10, 50)
|
||||
|
|
Loading…
Reference in New Issue