unrolled type
This commit is contained in:
parent
2a58b23085
commit
823792bc19
|
@ -59,7 +59,7 @@ function break!(g::IVertex)
|
|||
cse(group(group(loops...), g)), defaults
|
||||
end
|
||||
|
||||
function unroll(model, n)
|
||||
function unrollgraph(model, n)
|
||||
graph, defaults = break!(atomise(model))
|
||||
outputs = [spliceinputs(graph, group(map(constant, defaults)...), inputnode(1))]
|
||||
detuple(outputs[end])
|
||||
|
@ -71,6 +71,16 @@ function unroll(model, n)
|
|||
@> group(state, group(outputs...)) detuple
|
||||
end
|
||||
|
||||
type Unrolled <: Model
|
||||
model
|
||||
graph::IVertex{Any}
|
||||
steps::Int
|
||||
end
|
||||
|
||||
graph(u::Unrolled) = u.graph
|
||||
|
||||
unroll(model, n) = Unrolled(model, unrollgraph(model, n), n)
|
||||
|
||||
@net type Recurrent
|
||||
Wxh; Whh; Why
|
||||
hidden
|
||||
|
@ -86,4 +96,4 @@ Recurrent() = Recurrent((rand(i,i) for i = 1:4)...)
|
|||
# syntax′(x) = syntax(Flow.dl(x), bindconst = true)
|
||||
|
||||
# r = Recurrent()
|
||||
# unroll(r,10) |> cse |> syntax′ |> prettify |> display
|
||||
# unrollgraph(r,5) |> cse |> syntax′ |> prettify |> display
|
||||
|
|
Loading…
Reference in New Issue