From bd6bffde48ef0481f5a0a27e921761c7ebe57e99 Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Sat, 19 Aug 2017 20:04:21 +0100 Subject: [PATCH] silo the compiler --- src/Flux.jl | 18 +++++------------- src/compiler/Compiler.jl | 14 ++++++++++++++ src/compiler/code.jl | 7 ++++--- src/compiler/loops.jl | 2 ++ src/layers/control.jl | 2 +- src/utils.jl | 21 --------------------- test/backend/common.jl | 2 +- test/basic.jl | 6 ++++-- test/recurrent.jl | 6 ++++-- test/runtests.jl | 3 ++- 10 files changed, 37 insertions(+), 44 deletions(-) create mode 100644 src/compiler/Compiler.jl diff --git a/src/Flux.jl b/src/Flux.jl index 742c8b0e..0db1ed7e 100644 --- a/src/Flux.jl +++ b/src/Flux.jl @@ -2,17 +2,10 @@ __precompile__() module Flux -using MacroTools, Lazy, DataFlow, Juno -using DataFlow: graphm, syntax, prewalk!, postwalk!, prewalk, postwalk, - iscyclic, Constant, constant, isconstant, group, Split, splitnode, - detuple, value, inputs, thread!, value, inputs, Split, splitnode, inputnode, - spliceinputs, bumpinputs, Line, Frame, applylines, graphinputs -using DataFlow.Interpreter +using Juno +using Lazy: @forward -export @net, unroll, unroll1, @shapes, - @Chain, Chain, Input, Affine, Conv2D, Recurrent, GatedRecurrent, LSTM, - σ, relu, softmax, - tf, mxnet +export Chain, Affine, σ, softmax # Zero Flux Given @@ -23,9 +16,8 @@ export track, back! include("utils.jl") include("params.jl") -include("compiler/code.jl") -include("compiler/loops.jl") -include("compiler/interp.jl") +include("compiler/Compiler.jl") +using .Compiler: @net include("layers/control.jl") include("layers/affine.jl") diff --git a/src/compiler/Compiler.jl b/src/compiler/Compiler.jl new file mode 100644 index 00000000..ef65c483 --- /dev/null +++ b/src/compiler/Compiler.jl @@ -0,0 +1,14 @@ +module Compiler + +using MacroTools, DataFlow, DataFlow.Interpreter + +using DataFlow: graphm, syntax, prewalk!, postwalk!, prewalk, postwalk, + iscyclic, Constant, constant, isconstant, group, Split, splitnode, + detuple, value, inputs, thread!, value, inputs, Split, splitnode, inputnode, + spliceinputs, bumpinputs, Line, Frame, applylines, graphinputs + +include("code.jl") +include("interp.jl") +include("loops.jl") + +end diff --git a/src/compiler/code.jl b/src/compiler/code.jl index 2554fe63..027f862c 100644 --- a/src/compiler/code.jl +++ b/src/compiler/code.jl @@ -1,5 +1,6 @@ import DataFlow: cse using MacroTools: @q, @> +import ..Flux: Param, param, state graph(m) = nothing @@ -18,11 +19,11 @@ function makegraph(graph, args, params = []) end graph = map(graph) do x x isa Offset ? - :(Flux.Offset($(Expr(:quote, x.name)), $(x.n), + :(Flux.Compiler.Offset($(Expr(:quote, x.name)), $(x.n), $(x.name in params ? :(self.$(x.name)) : x.name))) : x end - vertex(:(Flux.Frame(self)), graph) + vertex(:($DataFlow.Frame(self)), graph) end function build_type(T, params) @@ -68,7 +69,7 @@ function process_type(ex) quote $(build_type(T, params)) $(esc(:((self::$T)($(args...)) = $(build_forward(body, args))))) - $(esc(:(Flux.graph(self::$T)))) = $(DataFlow.constructor(map(esc, makegraph(body, args, params)))) + $(esc(:(Flux.Compiler.graph(self::$T)))) = $(DataFlow.constructor(map(esc, makegraph(body, args, params)))) nothing end end diff --git a/src/compiler/loops.jl b/src/compiler/loops.jl index 50104507..dc228526 100644 --- a/src/compiler/loops.jl +++ b/src/compiler/loops.jl @@ -1,3 +1,5 @@ +using ..Flux: stack, unstack + # Stateful Models mutable struct Stateful diff --git a/src/layers/control.jl b/src/layers/control.jl index fca15c0d..0003812c 100644 --- a/src/layers/control.jl +++ b/src/layers/control.jl @@ -8,7 +8,7 @@ end (s::Chain)(x) = foldl((x, m) -> m(x), x, s.layers) -graph(s::Chain) = +Compiler.graph(s::Chain) = foldl((v, m) -> vertex(m, v), constant(inputnode(1)), s.layers) Base.getindex(c::Chain, i::AbstractArray) = Chain(c.layers[i]...) diff --git a/src/utils.jl b/src/utils.jl index c77202e1..958d04e4 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -15,27 +15,6 @@ convertel{T}(::Type{T}, xs::AbstractArray{T}) = xs a ∘ b = a .* b -broadcastto(xs::AbstractArray, shape) = xs .* ones(shape) - -# Tuples - -mapt(f, x) = f(x) -mapt(f, xs::Tuple) = map(x -> mapt(f, x), xs) - -function collectt(xs) - ys = [] - mapt(x -> push!(ys, x), xs) - return ys -end - -function shapecheckt(xs::Tuple, ys::Tuple) - length(xs) == length(ys) || error("Expected tuple length $(length(xs)), got $ys") - shapecheckt.(xs, ys) -end - -shapecheckt(xs::Tuple, ys) = error("Expected tuple, got $ys") -shapecheckt(xs, ys) = nothing - # Other function accuracy(m, data) diff --git a/test/backend/common.jl b/test/backend/common.jl index e001f436..eb6b644b 100644 --- a/test/backend/common.jl +++ b/test/backend/common.jl @@ -19,7 +19,7 @@ end function test_recurrence(bk) @testset "Recurrence" begin seq = unsqueeze(stack(rand(10) for i = 1:3)) - r = unroll(Recurrent(10, 5), 3) + r = Flux.Compiler.unroll(Recurrent(10, 5), 3) rm = bk(r) @test r(seq) ≈ rm(seq) end diff --git a/test/basic.jl b/test/basic.jl index d2b28d3c..efd74f4c 100644 --- a/test/basic.jl +++ b/test/basic.jl @@ -1,3 +1,5 @@ +using Flux: Affine + syntax(v::Vertex) = prettify(DataFlow.syntax(v)) syntax(x) = syntax(graph(x)) @@ -21,12 +23,12 @@ test_anon(identity) let a1 = Affine(10, 20), a2 = Affine(20, 15) tlp = TLP(a1, a2) @test tlp(xs) ≈ softmax(a2(σ(a1(xs)))) - @test Flux.interpmodel(tlp, xs) ≈ softmax(a2(σ(a1(xs)))) + @test Flux.Compiler.interpmodel(tlp, xs) ≈ softmax(a2(σ(a1(xs)))) end let tlp = TLP(Affine(10, 21), Affine(20, 15)) e = try - Flux.interpmodel(tlp, rand(1, 10)) + Flux.Compiler.interpmodel(tlp, rand(1, 10)) catch e e end diff --git a/test/recurrent.jl b/test/recurrent.jl index 236a5b59..eb048106 100644 --- a/test/recurrent.jl +++ b/test/recurrent.jl @@ -1,3 +1,5 @@ +using Flux: Recurrent + function apply(model, xs, state) ys = similar(xs, 0) for x in xs @@ -10,8 +12,8 @@ end @testset "RNN unrolling" begin r = Recurrent(10, 5) xs = [rand(1, 10) for _ = 1:3] - _, ys = apply(unroll1(r).model, xs, (r.y.x,)) + _, ys = apply(Flux.Compiler.unroll1(r).model, xs, (r.y.x,)) @test ys[1] == tanh(xs[1] * r.Wxy.x .+ r.y.x * r.Wyy.x .+ r.by.x) - ru = unroll(r, 3) + ru = Flux.Compiler.unroll(r, 3) ru(unsqueeze(stack(squeeze.(xs))))[1] == squeeze.(ys) end diff --git a/test/runtests.jl b/test/runtests.jl index 39c56ee4..9f899897 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,5 +1,6 @@ using Flux, DataFlow, MacroTools, Base.Test -using Flux: graph, Param, squeeze, unsqueeze, stack, update!, flatten +using Flux: Param, param, squeeze, unsqueeze, stack, update!, flatten +using Flux.Compiler: @net using DataFlow: Line, Frame @testset "Flux" begin