From 9c9bfa6f138c577dcd764182266a5dec065d7f50 Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Mon, 29 Aug 2016 15:17:54 +0100 Subject: [PATCH] loop lifting --- src/Flux.jl | 1 + src/compiler/code.jl | 2 +- src/compiler/loops.jl | 49 +++++++++++++++++++++++++++++++++++++++++++ src/layers/dense.jl | 10 --------- 4 files changed, 51 insertions(+), 11 deletions(-) create mode 100644 src/compiler/loops.jl diff --git a/src/Flux.jl b/src/Flux.jl index f4121a35..ccc69c0e 100644 --- a/src/Flux.jl +++ b/src/Flux.jl @@ -9,6 +9,7 @@ include("utils.jl") include("compiler/diff.jl") include("compiler/code.jl") +include("compiler/loops.jl") include("layers/dense.jl") include("layers/shape.jl") diff --git a/src/compiler/code.jl b/src/compiler/code.jl index fd14507c..352e9ca5 100644 --- a/src/compiler/code.jl +++ b/src/compiler/code.jl @@ -6,7 +6,7 @@ export @model function process_func(ex, params = []) @capture(shortdef(ex), (args__,) -> body_) - body = Flow.il(graphm(unblock(body))) + body = @> body MacroTools.flatten liftloops!(params) graphm Flow.il body = mapconst(x -> x in params ? :(self.$x) : x, body) return args, body end diff --git a/src/compiler/loops.jl b/src/compiler/loops.jl new file mode 100644 index 00000000..a48c541c --- /dev/null +++ b/src/compiler/loops.jl @@ -0,0 +1,49 @@ +type Delay + name::Symbol +end + +function liftloops!(ex, params) + e = Flow.normedges(ex) + hidden = intersect((b.args[1] for b in ex.args), params) + edges = Dict(h => gensym("edge") for h in hidden) + for b in ex.args + b.args[2] = MacroTools.postwalk(x -> get(edges, x, x), b.args[2]) + end + for (h, e) in edges + unshift!(ex.args, :($e = $(Delay(h))($h))) + end + return ex +end + +bumpinput(i::ModelInput) = isa(i.name, Integer) ? ModelInput(i.name + 1) : i +bumpinput(x) = x + +bumpinputs(v::IVertex) = mapconst(bumpinput, v) + +function break!(model) + iscyclic(graph(model)) || return model + bumpinputs(graph(model)) +end + +# r = Recurrent(784, 10, 50) + +# break!(r) + +@model type Recurrent + Wxh; Whh; Bh + Wxy; Why; By + hidden + + function (x) + hidden = σ( Wxh*x + Whh*hidden + Bh ) + y = σ( Wxy*x + Why*hidden + By ) + end +end + +Recurrent(in::Integer, out::Integer, hidden::Integer; init = initn) = + Recurrent(init(hidden, in), init(hidden, hidden), init(hidden), + init(out, in), init(out, hidden), init(hidden), + zeros(hidden)) + +Base.show(io::IO, r::Recurrent) = + print(io, "Flux.Recurrent(...)") diff --git a/src/layers/dense.jl b/src/layers/dense.jl index 1a3fa302..e2c3682e 100644 --- a/src/layers/dense.jl +++ b/src/layers/dense.jl @@ -21,13 +21,3 @@ end Sigmoid(in::Integer, out::Integer; init = randn) = Sigmoid(Dense(in, out, init = init)) - -# @model type Recurrent -# Wxh; Whh; Bh -# Wxy; Why; By -# -# function (x) -# hidden = σ( Wxh*x + Whh*hidden + Bh ) -# y = σ( Wxy*x + Why*hidden + By ) -# end -# end