unrolled type

This commit is contained in:
Mike J Innes 2016-10-26 11:57:03 +01:00
parent 2a58b23085
commit 823792bc19
1 changed files with 12 additions and 2 deletions

View File

@ -59,7 +59,7 @@ function break!(g::IVertex)
cse(group(group(loops...), g)), defaults
end
function unroll(model, n)
function unrollgraph(model, n)
graph, defaults = break!(atomise(model))
outputs = [spliceinputs(graph, group(map(constant, defaults)...), inputnode(1))]
detuple(outputs[end])
@ -71,6 +71,16 @@ function unroll(model, n)
@> group(state, group(outputs...)) detuple
end
type Unrolled <: Model
model
graph::IVertex{Any}
steps::Int
end
graph(u::Unrolled) = u.graph
unroll(model, n) = Unrolled(model, unrollgraph(model, n), n)
@net type Recurrent
Wxh; Whh; Why
hidden
@ -86,4 +96,4 @@ Recurrent() = Recurrent((rand(i,i) for i = 1:4)...)
# syntax(x) = syntax(Flow.dl(x), bindconst = true)
# r = Recurrent()
# unroll(r,10) |> cse |> syntax |> prettify |> display
# unrollgraph(r,5) |> cse |> syntax |> prettify |> display