basic unrolling
This commit is contained in:
parent
1fde7b4615
commit
14e4117837
|
@ -56,13 +56,19 @@ function break!(g::IVertex)
|
|||
cse(vertex(tuple, vertex(tuple, loops...), g)), defaults
|
||||
end
|
||||
|
||||
# r = Recurrent(10, 10)
|
||||
# r = Chain(Recurrent(10,10), Dense(10,10))
|
||||
# r = Chain(Recurrent(10,10),Recurrent(10,10))
|
||||
#
|
||||
# atomise(r) |> syntax |> prettify |> display
|
||||
#
|
||||
# break!(atomise(r)) |> syntax |> prettify |> display
|
||||
function unroll(model, n)
|
||||
graph, defaults = break!(atomise(model))
|
||||
outputs = [spliceinputs(graph, vertex(tuple, map(constant, defaults)...), constant(ModelInput(1)))]
|
||||
for i = 2:n
|
||||
push!(outputs, spliceinputs(graph, outputs[end][1], constant(ModelInput(i))))
|
||||
end
|
||||
state = outputs[end][1]
|
||||
outputs = map(x -> x[2], outputs)
|
||||
vertex(tuple, state, vertex(tuple, outputs...))
|
||||
end
|
||||
|
||||
# r = Chain(Recurrent(30,10), Recurrent(10, 20))
|
||||
# unroll(r, 1) |> syntax |> prettify |> display
|
||||
|
||||
@net type Recurrent
|
||||
Wx; Wh; B
|
||||
|
|
Loading…
Reference in New Issue