basic unrolling

This commit is contained in:
Mike J Innes 2016-08-29 16:23:43 +01:00
parent 19b5e8bd21
commit cd968af228
1 changed files with 14 additions and 1 deletions

View File

@ -20,9 +20,22 @@ bumpinput(x) = x
bumpinputs(v::IVertex) = mapconst(bumpinput, v)
function unroll(delay::IVertex)
prewalk(delay[1]) do v
isa(value(v), Delay) ? constant(ModelInput(1)) : v
end
end
function break!(model)
iscyclic(graph(model)) || return model
bumpinputs(graph(model))
g = bumpinputs(graph(model))
loops = []
g = prewalk(g) do v
isa(value(v), Delay) || return v
push!(loops, unroll(v))
constant(ModelInput(1))
end
cse(vertex(tuple, loops..., g))
end
# r = Recurrent(784, 10, 50)