From 42a7a6ebf6dfe552b8249c6fdd16c6ac58c53c49 Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Mon, 7 Nov 2016 19:44:51 +0000 Subject: [PATCH] delay -> offset --- src/compiler/code.jl | 4 +- src/compiler/loops.jl | 87 ++++++++++++++++++++----------------------- 2 files changed, 42 insertions(+), 49 deletions(-) diff --git a/src/compiler/code.jl b/src/compiler/code.jl index 675030e8..d4c42846 100644 --- a/src/compiler/code.jl +++ b/src/compiler/code.jl @@ -6,7 +6,7 @@ export @net function process_func(ex, params = []) @capture(shortdef(ex), (args__,) -> body_) - body = @> body MacroTools.flatten liftloops!(params) graphm DataFlow.il + body = @> body MacroTools.flatten liftloops(params) graphm DataFlow.il body = mapconst(x -> x in params ? :(self.$x) : x, body) return args, body end @@ -15,7 +15,7 @@ function makegraph(graph, args) @assert length(args) == 1 mapconst(graph) do x x == args[1] ? inputnode(1) : - isa(x, Delay) ? :(Delay($(Expr(:quote, x.name)), self.$(x.name))) : + isa(x, Offset) ? :(Offset($(Expr(:quote, x.name)), $(x.n), self.$(x.name))) : x end end diff --git a/src/compiler/loops.jl b/src/compiler/loops.jl index b2f055fe..55da8dd5 100644 --- a/src/compiler/loops.jl +++ b/src/compiler/loops.jl @@ -1,26 +1,19 @@ export unroll -type Delay +type Offset name::Symbol + n::Int default::Nullable{Param} end -Delay(name) = Delay(name, nothing) +Offset(name, n) = Offset(name, n, nothing) -function liftloops!(ex, params) +function liftloops(ex, params) ex = DataFlow.normedges(ex) - hidden = intersect((b.args[1] for b in ex.args), params) - edges = Dict(h => gensym("edge") for h in hidden) - declared = Dict(h => false for h in hidden) - liftvar(s) = get(declared, s, false) ? s : get(edges, s, s) - for b in ex.args - b.args[2] = MacroTools.postwalk(liftvar, b.args[2]) - declared[b.args[1]] = true + MacroTools.postwalk(ex) do ex + @capture(ex, x_{n_}) || return ex + :($(Offset(x,n))($x)) end - for (h, e) in edges - unshift!(ex.args, :($e = $(Delay(h))($h))) - end - return ex end function hasloops(model) @@ -39,39 +32,39 @@ function atomise(model) end end -hiddeninput(n) = vertex(Split(n), inputnode(1)) - -function unroll!(delay::IVertex, n) - prewalk!(delay[1]) do v - v === delay ? hiddeninput(n) : v - end -end - -function break!(g::IVertex) - g = bumpinputs(g) - loops = [] - defaults = [] - g = prewalk!(g) do v - isa(value(v), Delay) || return v - n = length(loops)+1 - push!(loops, unroll!(v, n)) - push!(defaults, get(value(v).default)) - hiddeninput(n) - end - cse(group(group(loops...), g)), defaults -end - -function unrollgraph(model, n) - graph, defaults = break!(atomise(model)) - outputs = [spliceinputs(graph, group([constant(splitnode(inputnode(1),i)) for i = 1:length(defaults)]...), - splitnode(inputnode(2), 1))] - for i = 2:n - push!(outputs, spliceinputs(graph, outputs[end][1], splitnode(inputnode(2), i))) - end - state = outputs[end][1] - outputs = map(x -> x[2], outputs) - (@> group(state, group(outputs...)) detuple), map(x->x.x, defaults) -end +# hiddeninput(n) = vertex(Split(n), inputnode(1)) +# +# function unroll!(delay::IVertex, n) +# prewalk!(delay[1]) do v +# v === delay ? hiddeninput(n) : v +# end +# end +# +# function break!(g::IVertex) +# g = bumpinputs(g) +# loops = [] +# defaults = [] +# g = prewalk!(g) do v +# isa(value(v), Offset) || return v +# n = length(loops)+1 +# push!(loops, unroll!(v, n)) +# push!(defaults, get(value(v).default)) +# hiddeninput(n) +# end +# cse(group(group(loops...), g)), defaults +# end +# +# function unrollgraph(model, n) +# graph, defaults = break!(atomise(model)) +# outputs = [spliceinputs(graph, group([constant(splitnode(inputnode(1),i)) for i = 1:length(defaults)]...), +# splitnode(inputnode(2), 1))] +# for i = 2:n +# push!(outputs, spliceinputs(graph, outputs[end][1], splitnode(inputnode(2), i))) +# end +# state = outputs[end][1] +# outputs = map(x -> x[2], outputs) +# (@> group(state, group(outputs...)) detuple), map(x->x.x, defaults) +# end type Unrolled <: Model model