basic unrolling
This commit is contained in:
parent
1fde7b4615
commit
14e4117837
@ -56,13 +56,19 @@ function break!(g::IVertex)
|
|||||||
cse(vertex(tuple, vertex(tuple, loops...), g)), defaults
|
cse(vertex(tuple, vertex(tuple, loops...), g)), defaults
|
||||||
end
|
end
|
||||||
|
|
||||||
# r = Recurrent(10, 10)
|
function unroll(model, n)
|
||||||
# r = Chain(Recurrent(10,10), Dense(10,10))
|
graph, defaults = break!(atomise(model))
|
||||||
# r = Chain(Recurrent(10,10),Recurrent(10,10))
|
outputs = [spliceinputs(graph, vertex(tuple, map(constant, defaults)...), constant(ModelInput(1)))]
|
||||||
#
|
for i = 2:n
|
||||||
# atomise(r) |> syntax |> prettify |> display
|
push!(outputs, spliceinputs(graph, outputs[end][1], constant(ModelInput(i))))
|
||||||
#
|
end
|
||||||
# break!(atomise(r)) |> syntax |> prettify |> display
|
state = outputs[end][1]
|
||||||
|
outputs = map(x -> x[2], outputs)
|
||||||
|
vertex(tuple, state, vertex(tuple, outputs...))
|
||||||
|
end
|
||||||
|
|
||||||
|
# r = Chain(Recurrent(30,10), Recurrent(10, 20))
|
||||||
|
# unroll(r, 1) |> syntax |> prettify |> display
|
||||||
|
|
||||||
@net type Recurrent
|
@net type Recurrent
|
||||||
Wx; Wh; B
|
Wx; Wh; B
|
||||||
|
Loading…
Reference in New Issue
Block a user