From f52b0140a50589267a6000d6a3e6921b95f04474 Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Sat, 13 Aug 2016 00:33:39 +0100 Subject: [PATCH] remove loop compilation --- src/Flux.jl | 1 - src/compiler/code.jl | 20 +-------------- src/compiler/loop.jl | 59 -------------------------------------------- 3 files changed, 1 insertion(+), 79 deletions(-) delete mode 100644 src/compiler/loop.jl diff --git a/src/Flux.jl b/src/Flux.jl index bee668ec..c48d4855 100644 --- a/src/Flux.jl +++ b/src/Flux.jl @@ -15,7 +15,6 @@ update!(m::Model, η) = m include("capacitor.jl") include("compiler/diff.jl") -include("compiler/loop.jl") include("compiler/code.jl") include("cost.jl") diff --git a/src/compiler/code.jl b/src/compiler/code.jl index 188b9116..335c7b71 100644 --- a/src/compiler/code.jl +++ b/src/compiler/code.jl @@ -17,27 +17,20 @@ function build_type(T, params) end function build_forward(body, args) - body = cut_forward(body, args) cse(body) end function build_backward(body, x, params) - Δs, Δloops = cut_backward(body, [x]) + Δs = invert(body) back = IVertex{Any}(Flow.Do()) for param in params haskey(Δs, :(self.$param)) || continue k = symbol("Δ", param) ksym = Expr(:quote, k) ex = Δs[:(self.$param)] - for Δloop in Δloops - ex = addΔ(ex, get(Δloop, :(self.$param), vertex(0))) - end thread!(back, @v(setfield!(:self, ksym, :(self.$k) + ex))) end ex = Δs[x] - for Δloop in Δloops - ex = addΔ(ex, get(Δloop, x, vertex(0))) - end thread!(back, @flow(tuple($ex))) cse(back) end @@ -69,16 +62,5 @@ end # process_type(:(type Sigmoid # W # b -# bp # x -> σ(W*x+b) # end)) |> prettify - -process_type(:(type Recurrent - Wxh; Whh; Bh - Why; By - - function (x) - hidden = σ( Wxh*x + Whh*Delay(hidden) + Bh ) - y = σ( Why*hidden + By ) - end -end)) |> prettify diff --git a/src/compiler/loop.jl b/src/compiler/loop.jl deleted file mode 100644 index 91c02deb..00000000 --- a/src/compiler/loop.jl +++ /dev/null @@ -1,59 +0,0 @@ -function delays(v::IVertex) - ds = [] - Flow.prefor(v) do w - value(w) == :Delay && - push!(ds, w) - end - return ds -end - -function cut(v::IVertex, f = _ -> il(@flow(last(self.delay)))) - prewalk(v) do v - value(v) == :Delay ? f(v) : v - end -end - -replaceall(d::Dict, args...) = Dict(k => replace(v, args...) for (k, v) in d) - -# Create the forward function; a single delay node becomes an -# input and an output node. -function cut_forward(v::IVertex, params, ds = delays(v)) - pushes = map(x->vertex(:push!, vertex(:(self.delay)), x[1], map(vertex, params)...), ds) - isempty(pushes) && return v - @assert length(pushes) == 1 - v = vertex(Flow.Do(), pushes..., v) - cut(v) -end - -# Given a delay node, give the parameter gradients with respect to -# the node and a function which will propagate gradients around -# the loop. -function invertloop(v::IVertex, params) - @gensym input - v = cut(v[1], v -> vertex(input)) - Δs = invert(v, @flow(Δloop)) - Δs = replaceall(Δs, vertex(input), il(@flow(last(self.delay)))) - Δs, :((Δ, $input, $(params...)) -> $(syntax(cse(Δs[input])))) -end - -# Returns: -# Parameter gradients with respect to the function -# Parameter gradients with respect to each delay node -function cut_backward(v::IVertex, params, ds = delays(v)) - isempty(ds) && return invert(v), [] - @assert length(ds) == 1 - @gensym input - Δs = invert(cut(v, _ -> vertex(input))) - Δs = replaceall(Δs, vertex(input), il(@flow(last(self.delay)))) - Δloop, ∇loop = invertloop(ds[1], params) - Δh = vertex(:back!, vertex(:(self.delay)), Δs[input], vertex(∇loop)) - Δloop = replaceall(Δloop, vertex(:Δloop), Δh) - Δs, [Δloop] -end - -# g = il(@flow begin -# hidden = σ( Wxh*x + Whh*Delay(hidden) + bh ) -# y = σ( Why*hidden + by ) -# end) - -# cut_backward(g, [:x])[1]