basic unrolling
This commit is contained in:
parent
19b5e8bd21
commit
cd968af228
@ -20,9 +20,22 @@ bumpinput(x) = x
|
|||||||
|
|
||||||
bumpinputs(v::IVertex) = mapconst(bumpinput, v)
|
bumpinputs(v::IVertex) = mapconst(bumpinput, v)
|
||||||
|
|
||||||
|
function unroll(delay::IVertex)
|
||||||
|
prewalk(delay[1]) do v
|
||||||
|
isa(value(v), Delay) ? constant(ModelInput(1)) : v
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
function break!(model)
|
function break!(model)
|
||||||
iscyclic(graph(model)) || return model
|
iscyclic(graph(model)) || return model
|
||||||
bumpinputs(graph(model))
|
g = bumpinputs(graph(model))
|
||||||
|
loops = []
|
||||||
|
g = prewalk(g) do v
|
||||||
|
isa(value(v), Delay) || return v
|
||||||
|
push!(loops, unroll(v))
|
||||||
|
constant(ModelInput(1))
|
||||||
|
end
|
||||||
|
cse(vertex(tuple, loops..., g))
|
||||||
end
|
end
|
||||||
|
|
||||||
# r = Recurrent(784, 10, 50)
|
# r = Recurrent(784, 10, 50)
|
||||||
|
Loading…
Reference in New Issue
Block a user