silo the compiler
This commit is contained in:
parent
8ed4d569b3
commit
bd6bffde48
18
src/Flux.jl
18
src/Flux.jl
@ -2,17 +2,10 @@ __precompile__()
|
||||
|
||||
module Flux
|
||||
|
||||
using MacroTools, Lazy, DataFlow, Juno
|
||||
using DataFlow: graphm, syntax, prewalk!, postwalk!, prewalk, postwalk,
|
||||
iscyclic, Constant, constant, isconstant, group, Split, splitnode,
|
||||
detuple, value, inputs, thread!, value, inputs, Split, splitnode, inputnode,
|
||||
spliceinputs, bumpinputs, Line, Frame, applylines, graphinputs
|
||||
using DataFlow.Interpreter
|
||||
using Juno
|
||||
using Lazy: @forward
|
||||
|
||||
export @net, unroll, unroll1, @shapes,
|
||||
@Chain, Chain, Input, Affine, Conv2D, Recurrent, GatedRecurrent, LSTM,
|
||||
σ, relu, softmax,
|
||||
tf, mxnet
|
||||
export Chain, Affine, σ, softmax
|
||||
|
||||
# Zero Flux Given
|
||||
|
||||
@ -23,9 +16,8 @@ export track, back!
|
||||
include("utils.jl")
|
||||
include("params.jl")
|
||||
|
||||
include("compiler/code.jl")
|
||||
include("compiler/loops.jl")
|
||||
include("compiler/interp.jl")
|
||||
include("compiler/Compiler.jl")
|
||||
using .Compiler: @net
|
||||
|
||||
include("layers/control.jl")
|
||||
include("layers/affine.jl")
|
||||
|
14
src/compiler/Compiler.jl
Normal file
14
src/compiler/Compiler.jl
Normal file
@ -0,0 +1,14 @@
|
||||
module Compiler
|
||||
|
||||
using MacroTools, DataFlow, DataFlow.Interpreter
|
||||
|
||||
using DataFlow: graphm, syntax, prewalk!, postwalk!, prewalk, postwalk,
|
||||
iscyclic, Constant, constant, isconstant, group, Split, splitnode,
|
||||
detuple, value, inputs, thread!, value, inputs, Split, splitnode, inputnode,
|
||||
spliceinputs, bumpinputs, Line, Frame, applylines, graphinputs
|
||||
|
||||
include("code.jl")
|
||||
include("interp.jl")
|
||||
include("loops.jl")
|
||||
|
||||
end
|
@ -1,5 +1,6 @@
|
||||
import DataFlow: cse
|
||||
using MacroTools: @q, @>
|
||||
import ..Flux: Param, param, state
|
||||
|
||||
graph(m) = nothing
|
||||
|
||||
@ -18,11 +19,11 @@ function makegraph(graph, args, params = [])
|
||||
end
|
||||
graph = map(graph) do x
|
||||
x isa Offset ?
|
||||
:(Flux.Offset($(Expr(:quote, x.name)), $(x.n),
|
||||
:(Flux.Compiler.Offset($(Expr(:quote, x.name)), $(x.n),
|
||||
$(x.name in params ? :(self.$(x.name)) : x.name))) :
|
||||
x
|
||||
end
|
||||
vertex(:(Flux.Frame(self)), graph)
|
||||
vertex(:($DataFlow.Frame(self)), graph)
|
||||
end
|
||||
|
||||
function build_type(T, params)
|
||||
@ -68,7 +69,7 @@ function process_type(ex)
|
||||
quote
|
||||
$(build_type(T, params))
|
||||
$(esc(:((self::$T)($(args...)) = $(build_forward(body, args)))))
|
||||
$(esc(:(Flux.graph(self::$T)))) = $(DataFlow.constructor(map(esc, makegraph(body, args, params))))
|
||||
$(esc(:(Flux.Compiler.graph(self::$T)))) = $(DataFlow.constructor(map(esc, makegraph(body, args, params))))
|
||||
nothing
|
||||
end
|
||||
end
|
||||
|
@ -1,3 +1,5 @@
|
||||
using ..Flux: stack, unstack
|
||||
|
||||
# Stateful Models
|
||||
|
||||
mutable struct Stateful
|
||||
|
@ -8,7 +8,7 @@ end
|
||||
|
||||
(s::Chain)(x) = foldl((x, m) -> m(x), x, s.layers)
|
||||
|
||||
graph(s::Chain) =
|
||||
Compiler.graph(s::Chain) =
|
||||
foldl((v, m) -> vertex(m, v), constant(inputnode(1)), s.layers)
|
||||
|
||||
Base.getindex(c::Chain, i::AbstractArray) = Chain(c.layers[i]...)
|
||||
|
21
src/utils.jl
21
src/utils.jl
@ -15,27 +15,6 @@ convertel{T}(::Type{T}, xs::AbstractArray{T}) = xs
|
||||
|
||||
a ∘ b = a .* b
|
||||
|
||||
broadcastto(xs::AbstractArray, shape) = xs .* ones(shape)
|
||||
|
||||
# Tuples
|
||||
|
||||
mapt(f, x) = f(x)
|
||||
mapt(f, xs::Tuple) = map(x -> mapt(f, x), xs)
|
||||
|
||||
function collectt(xs)
|
||||
ys = []
|
||||
mapt(x -> push!(ys, x), xs)
|
||||
return ys
|
||||
end
|
||||
|
||||
function shapecheckt(xs::Tuple, ys::Tuple)
|
||||
length(xs) == length(ys) || error("Expected tuple length $(length(xs)), got $ys")
|
||||
shapecheckt.(xs, ys)
|
||||
end
|
||||
|
||||
shapecheckt(xs::Tuple, ys) = error("Expected tuple, got $ys")
|
||||
shapecheckt(xs, ys) = nothing
|
||||
|
||||
# Other
|
||||
|
||||
function accuracy(m, data)
|
||||
|
@ -19,7 +19,7 @@ end
|
||||
function test_recurrence(bk)
|
||||
@testset "Recurrence" begin
|
||||
seq = unsqueeze(stack(rand(10) for i = 1:3))
|
||||
r = unroll(Recurrent(10, 5), 3)
|
||||
r = Flux.Compiler.unroll(Recurrent(10, 5), 3)
|
||||
rm = bk(r)
|
||||
@test r(seq) ≈ rm(seq)
|
||||
end
|
||||
|
@ -1,3 +1,5 @@
|
||||
using Flux: Affine
|
||||
|
||||
syntax(v::Vertex) = prettify(DataFlow.syntax(v))
|
||||
syntax(x) = syntax(graph(x))
|
||||
|
||||
@ -21,12 +23,12 @@ test_anon(identity)
|
||||
let a1 = Affine(10, 20), a2 = Affine(20, 15)
|
||||
tlp = TLP(a1, a2)
|
||||
@test tlp(xs) ≈ softmax(a2(σ(a1(xs))))
|
||||
@test Flux.interpmodel(tlp, xs) ≈ softmax(a2(σ(a1(xs))))
|
||||
@test Flux.Compiler.interpmodel(tlp, xs) ≈ softmax(a2(σ(a1(xs))))
|
||||
end
|
||||
|
||||
let tlp = TLP(Affine(10, 21), Affine(20, 15))
|
||||
e = try
|
||||
Flux.interpmodel(tlp, rand(1, 10))
|
||||
Flux.Compiler.interpmodel(tlp, rand(1, 10))
|
||||
catch e
|
||||
e
|
||||
end
|
||||
|
@ -1,3 +1,5 @@
|
||||
using Flux: Recurrent
|
||||
|
||||
function apply(model, xs, state)
|
||||
ys = similar(xs, 0)
|
||||
for x in xs
|
||||
@ -10,8 +12,8 @@ end
|
||||
@testset "RNN unrolling" begin
|
||||
r = Recurrent(10, 5)
|
||||
xs = [rand(1, 10) for _ = 1:3]
|
||||
_, ys = apply(unroll1(r).model, xs, (r.y.x,))
|
||||
_, ys = apply(Flux.Compiler.unroll1(r).model, xs, (r.y.x,))
|
||||
@test ys[1] == tanh(xs[1] * r.Wxy.x .+ r.y.x * r.Wyy.x .+ r.by.x)
|
||||
ru = unroll(r, 3)
|
||||
ru = Flux.Compiler.unroll(r, 3)
|
||||
ru(unsqueeze(stack(squeeze.(xs))))[1] == squeeze.(ys)
|
||||
end
|
||||
|
@ -1,5 +1,6 @@
|
||||
using Flux, DataFlow, MacroTools, Base.Test
|
||||
using Flux: graph, Param, squeeze, unsqueeze, stack, update!, flatten
|
||||
using Flux: Param, param, squeeze, unsqueeze, stack, update!, flatten
|
||||
using Flux.Compiler: @net
|
||||
using DataFlow: Line, Frame
|
||||
|
||||
@testset "Flux" begin
|
||||
|
Loading…
Reference in New Issue
Block a user