From 21a3b952605aa9baf6cb35f2ed44ef6eedebd76b Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Tue, 8 Nov 2016 00:06:45 +0000 Subject: [PATCH] mostly recover old behaviour --- src/compiler/loops.jl | 89 +++++++++++++++++++++++++++---------------- 1 file changed, 56 insertions(+), 33 deletions(-) diff --git a/src/compiler/loops.jl b/src/compiler/loops.jl index 55da8dd5..66084a54 100644 --- a/src/compiler/loops.jl +++ b/src/compiler/loops.jl @@ -32,39 +32,53 @@ function atomise(model) end end -# hiddeninput(n) = vertex(Split(n), inputnode(1)) -# -# function unroll!(delay::IVertex, n) -# prewalk!(delay[1]) do v -# v === delay ? hiddeninput(n) : v -# end -# end -# -# function break!(g::IVertex) -# g = bumpinputs(g) -# loops = [] -# defaults = [] -# g = prewalk!(g) do v -# isa(value(v), Offset) || return v -# n = length(loops)+1 -# push!(loops, unroll!(v, n)) -# push!(defaults, get(value(v).default)) -# hiddeninput(n) -# end -# cse(group(group(loops...), g)), defaults -# end -# -# function unrollgraph(model, n) -# graph, defaults = break!(atomise(model)) -# outputs = [spliceinputs(graph, group([constant(splitnode(inputnode(1),i)) for i = 1:length(defaults)]...), -# splitnode(inputnode(2), 1))] -# for i = 2:n -# push!(outputs, spliceinputs(graph, outputs[end][1], splitnode(inputnode(2), i))) -# end -# state = outputs[end][1] -# outputs = map(x -> x[2], outputs) -# (@> group(state, group(outputs...)) detuple), map(x->x.x, defaults) -# end +function collect_state(v::IVertex) + state = typeof(v)[] + offset = Int[] + default = Param[] + prewalk!(v) do v + isa(value(v), Offset) || return v + if (i = findfirst(state, v[1])) == 0 + push!(state, v[1]) + push!(offset, value(v).n) + push!(default, get(value(v).default)) + else + offset[i] = min(offset[i], value(v).n) + end + v + end + return state, offset, default +end + +hiddeninput(n) = vertex(Split(n), inputnode(1)) + +function create_steps(v::IVertex, n) + [bumpinputs(spliceinputs(v, hiddeninput(i))) for i = 1:n] +end + +function unrollgraph(v::IVertex, n) + state, offset, default = collect_state(v) + v = group(group(state...), v) + steps = create_steps(v, n) + for i = 1:n + vars = inputs(steps[i][1]) + prewalk!(steps[i]) do v + isa(value(v), Offset) || return v + stepid = value(v).n + i + varid = findfirst(vars,v[1]) + if stepid ∈ 1:n + steps[stepid][1,varid] + elseif stepid < 1 + vertex(:input, constant(varid)) + elseif stepid > n + constant(:output, constant(varid)) + end + end + end + group(steps[end][1],group(map(x->x[2], steps)...)) +end + +unrollgraph(atomise(Chain(r,r)), 5) |> detuple |> syntax |> prettify type Unrolled <: Model model @@ -76,3 +90,12 @@ end graph(u::Unrolled) = u.graph unroll(model, n) = Unrolled(model, unrollgraph(model, n)..., n) + +@net type Recurrent + y + function (x) + y = σ(x, y{-1}) + end +end + +r = Recurrent(rand(5))