unrolled type
This commit is contained in:
parent
2a58b23085
commit
823792bc19
@ -59,7 +59,7 @@ function break!(g::IVertex)
|
|||||||
cse(group(group(loops...), g)), defaults
|
cse(group(group(loops...), g)), defaults
|
||||||
end
|
end
|
||||||
|
|
||||||
function unroll(model, n)
|
function unrollgraph(model, n)
|
||||||
graph, defaults = break!(atomise(model))
|
graph, defaults = break!(atomise(model))
|
||||||
outputs = [spliceinputs(graph, group(map(constant, defaults)...), inputnode(1))]
|
outputs = [spliceinputs(graph, group(map(constant, defaults)...), inputnode(1))]
|
||||||
detuple(outputs[end])
|
detuple(outputs[end])
|
||||||
@ -71,6 +71,16 @@ function unroll(model, n)
|
|||||||
@> group(state, group(outputs...)) detuple
|
@> group(state, group(outputs...)) detuple
|
||||||
end
|
end
|
||||||
|
|
||||||
|
type Unrolled <: Model
|
||||||
|
model
|
||||||
|
graph::IVertex{Any}
|
||||||
|
steps::Int
|
||||||
|
end
|
||||||
|
|
||||||
|
graph(u::Unrolled) = u.graph
|
||||||
|
|
||||||
|
unroll(model, n) = Unrolled(model, unrollgraph(model, n), n)
|
||||||
|
|
||||||
@net type Recurrent
|
@net type Recurrent
|
||||||
Wxh; Whh; Why
|
Wxh; Whh; Why
|
||||||
hidden
|
hidden
|
||||||
@ -86,4 +96,4 @@ Recurrent() = Recurrent((rand(i,i) for i = 1:4)...)
|
|||||||
# syntax′(x) = syntax(Flow.dl(x), bindconst = true)
|
# syntax′(x) = syntax(Flow.dl(x), bindconst = true)
|
||||||
|
|
||||||
# r = Recurrent()
|
# r = Recurrent()
|
||||||
# unroll(r,10) |> cse |> syntax′ |> prettify |> display
|
# unrollgraph(r,5) |> cse |> syntax′ |> prettify |> display
|
||||||
|
Loading…
Reference in New Issue
Block a user