silo the compiler

This commit is contained in:
Mike J Innes 2017-08-19 20:04:21 +01:00
parent 8ed4d569b3
commit bd6bffde48
10 changed files with 37 additions and 44 deletions

View File

@ -2,17 +2,10 @@ __precompile__()
module Flux
using MacroTools, Lazy, DataFlow, Juno
using DataFlow: graphm, syntax, prewalk!, postwalk!, prewalk, postwalk,
iscyclic, Constant, constant, isconstant, group, Split, splitnode,
detuple, value, inputs, thread!, value, inputs, Split, splitnode, inputnode,
spliceinputs, bumpinputs, Line, Frame, applylines, graphinputs
using DataFlow.Interpreter
using Juno
using Lazy: @forward
export @net, unroll, unroll1, @shapes,
@Chain, Chain, Input, Affine, Conv2D, Recurrent, GatedRecurrent, LSTM,
σ, relu, softmax,
tf, mxnet
export Chain, Affine, σ, softmax
# Zero Flux Given
@ -23,9 +16,8 @@ export track, back!
include("utils.jl")
include("params.jl")
include("compiler/code.jl")
include("compiler/loops.jl")
include("compiler/interp.jl")
include("compiler/Compiler.jl")
using .Compiler: @net
include("layers/control.jl")
include("layers/affine.jl")

14
src/compiler/Compiler.jl Normal file
View File

@ -0,0 +1,14 @@
module Compiler
using MacroTools, DataFlow, DataFlow.Interpreter
using DataFlow: graphm, syntax, prewalk!, postwalk!, prewalk, postwalk,
iscyclic, Constant, constant, isconstant, group, Split, splitnode,
detuple, value, inputs, thread!, value, inputs, Split, splitnode, inputnode,
spliceinputs, bumpinputs, Line, Frame, applylines, graphinputs
include("code.jl")
include("interp.jl")
include("loops.jl")
end

View File

@ -1,5 +1,6 @@
import DataFlow: cse
using MacroTools: @q, @>
import ..Flux: Param, param, state
graph(m) = nothing
@ -18,11 +19,11 @@ function makegraph(graph, args, params = [])
end
graph = map(graph) do x
x isa Offset ?
:(Flux.Offset($(Expr(:quote, x.name)), $(x.n),
:(Flux.Compiler.Offset($(Expr(:quote, x.name)), $(x.n),
$(x.name in params ? :(self.$(x.name)) : x.name))) :
x
end
vertex(:(Flux.Frame(self)), graph)
vertex(:($DataFlow.Frame(self)), graph)
end
function build_type(T, params)
@ -68,7 +69,7 @@ function process_type(ex)
quote
$(build_type(T, params))
$(esc(:((self::$T)($(args...)) = $(build_forward(body, args)))))
$(esc(:(Flux.graph(self::$T)))) = $(DataFlow.constructor(map(esc, makegraph(body, args, params))))
$(esc(:(Flux.Compiler.graph(self::$T)))) = $(DataFlow.constructor(map(esc, makegraph(body, args, params))))
nothing
end
end

View File

@ -1,3 +1,5 @@
using ..Flux: stack, unstack
# Stateful Models
mutable struct Stateful

View File

@ -8,7 +8,7 @@ end
(s::Chain)(x) = foldl((x, m) -> m(x), x, s.layers)
graph(s::Chain) =
Compiler.graph(s::Chain) =
foldl((v, m) -> vertex(m, v), constant(inputnode(1)), s.layers)
Base.getindex(c::Chain, i::AbstractArray) = Chain(c.layers[i]...)

View File

@ -15,27 +15,6 @@ convertel{T}(::Type{T}, xs::AbstractArray{T}) = xs
a b = a .* b
broadcastto(xs::AbstractArray, shape) = xs .* ones(shape)
# Tuples
mapt(f, x) = f(x)
mapt(f, xs::Tuple) = map(x -> mapt(f, x), xs)
function collectt(xs)
ys = []
mapt(x -> push!(ys, x), xs)
return ys
end
function shapecheckt(xs::Tuple, ys::Tuple)
length(xs) == length(ys) || error("Expected tuple length $(length(xs)), got $ys")
shapecheckt.(xs, ys)
end
shapecheckt(xs::Tuple, ys) = error("Expected tuple, got $ys")
shapecheckt(xs, ys) = nothing
# Other
function accuracy(m, data)

View File

@ -19,7 +19,7 @@ end
function test_recurrence(bk)
@testset "Recurrence" begin
seq = unsqueeze(stack(rand(10) for i = 1:3))
r = unroll(Recurrent(10, 5), 3)
r = Flux.Compiler.unroll(Recurrent(10, 5), 3)
rm = bk(r)
@test r(seq) rm(seq)
end

View File

@ -1,3 +1,5 @@
using Flux: Affine
syntax(v::Vertex) = prettify(DataFlow.syntax(v))
syntax(x) = syntax(graph(x))
@ -21,12 +23,12 @@ test_anon(identity)
let a1 = Affine(10, 20), a2 = Affine(20, 15)
tlp = TLP(a1, a2)
@test tlp(xs) softmax(a2(σ(a1(xs))))
@test Flux.interpmodel(tlp, xs) softmax(a2(σ(a1(xs))))
@test Flux.Compiler.interpmodel(tlp, xs) softmax(a2(σ(a1(xs))))
end
let tlp = TLP(Affine(10, 21), Affine(20, 15))
e = try
Flux.interpmodel(tlp, rand(1, 10))
Flux.Compiler.interpmodel(tlp, rand(1, 10))
catch e
e
end

View File

@ -1,3 +1,5 @@
using Flux: Recurrent
function apply(model, xs, state)
ys = similar(xs, 0)
for x in xs
@ -10,8 +12,8 @@ end
@testset "RNN unrolling" begin
r = Recurrent(10, 5)
xs = [rand(1, 10) for _ = 1:3]
_, ys = apply(unroll1(r).model, xs, (r.y.x,))
_, ys = apply(Flux.Compiler.unroll1(r).model, xs, (r.y.x,))
@test ys[1] == tanh(xs[1] * r.Wxy.x .+ r.y.x * r.Wyy.x .+ r.by.x)
ru = unroll(r, 3)
ru = Flux.Compiler.unroll(r, 3)
ru(unsqueeze(stack(squeeze.(xs))))[1] == squeeze.(ys)
end

View File

@ -1,5 +1,6 @@
using Flux, DataFlow, MacroTools, Base.Test
using Flux: graph, Param, squeeze, unsqueeze, stack, update!, flatten
using Flux: Param, param, squeeze, unsqueeze, stack, update!, flatten
using Flux.Compiler: @net
using DataFlow: Line, Frame
@testset "Flux" begin