basic unrolling

This commit is contained in:
Mike J Innes 2016-10-25 21:10:04 +01:00
parent 1fde7b4615
commit 14e4117837
1 changed files with 13 additions and 7 deletions

View File

@ -56,13 +56,19 @@ function break!(g::IVertex)
cse(vertex(tuple, vertex(tuple, loops...), g)), defaults
end
# r = Recurrent(10, 10)
# r = Chain(Recurrent(10,10), Dense(10,10))
# r = Chain(Recurrent(10,10),Recurrent(10,10))
#
# atomise(r) |> syntax |> prettify |> display
#
# break!(atomise(r)) |> syntax |> prettify |> display
function unroll(model, n)
graph, defaults = break!(atomise(model))
outputs = [spliceinputs(graph, vertex(tuple, map(constant, defaults)...), constant(ModelInput(1)))]
for i = 2:n
push!(outputs, spliceinputs(graph, outputs[end][1], constant(ModelInput(i))))
end
state = outputs[end][1]
outputs = map(x -> x[2], outputs)
vertex(tuple, state, vertex(tuple, outputs...))
end
# r = Chain(Recurrent(30,10), Recurrent(10, 20))
# unroll(r, 1) |> syntax |> prettify |> display
@net type Recurrent
Wx; Wh; B