silo the compiler
This commit is contained in:
parent
8ed4d569b3
commit
bd6bffde48
18
src/Flux.jl
18
src/Flux.jl
@ -2,17 +2,10 @@ __precompile__()
|
|||||||
|
|
||||||
module Flux
|
module Flux
|
||||||
|
|
||||||
using MacroTools, Lazy, DataFlow, Juno
|
using Juno
|
||||||
using DataFlow: graphm, syntax, prewalk!, postwalk!, prewalk, postwalk,
|
using Lazy: @forward
|
||||||
iscyclic, Constant, constant, isconstant, group, Split, splitnode,
|
|
||||||
detuple, value, inputs, thread!, value, inputs, Split, splitnode, inputnode,
|
|
||||||
spliceinputs, bumpinputs, Line, Frame, applylines, graphinputs
|
|
||||||
using DataFlow.Interpreter
|
|
||||||
|
|
||||||
export @net, unroll, unroll1, @shapes,
|
export Chain, Affine, σ, softmax
|
||||||
@Chain, Chain, Input, Affine, Conv2D, Recurrent, GatedRecurrent, LSTM,
|
|
||||||
σ, relu, softmax,
|
|
||||||
tf, mxnet
|
|
||||||
|
|
||||||
# Zero Flux Given
|
# Zero Flux Given
|
||||||
|
|
||||||
@ -23,9 +16,8 @@ export track, back!
|
|||||||
include("utils.jl")
|
include("utils.jl")
|
||||||
include("params.jl")
|
include("params.jl")
|
||||||
|
|
||||||
include("compiler/code.jl")
|
include("compiler/Compiler.jl")
|
||||||
include("compiler/loops.jl")
|
using .Compiler: @net
|
||||||
include("compiler/interp.jl")
|
|
||||||
|
|
||||||
include("layers/control.jl")
|
include("layers/control.jl")
|
||||||
include("layers/affine.jl")
|
include("layers/affine.jl")
|
||||||
|
14
src/compiler/Compiler.jl
Normal file
14
src/compiler/Compiler.jl
Normal file
@ -0,0 +1,14 @@
|
|||||||
|
module Compiler
|
||||||
|
|
||||||
|
using MacroTools, DataFlow, DataFlow.Interpreter
|
||||||
|
|
||||||
|
using DataFlow: graphm, syntax, prewalk!, postwalk!, prewalk, postwalk,
|
||||||
|
iscyclic, Constant, constant, isconstant, group, Split, splitnode,
|
||||||
|
detuple, value, inputs, thread!, value, inputs, Split, splitnode, inputnode,
|
||||||
|
spliceinputs, bumpinputs, Line, Frame, applylines, graphinputs
|
||||||
|
|
||||||
|
include("code.jl")
|
||||||
|
include("interp.jl")
|
||||||
|
include("loops.jl")
|
||||||
|
|
||||||
|
end
|
@ -1,5 +1,6 @@
|
|||||||
import DataFlow: cse
|
import DataFlow: cse
|
||||||
using MacroTools: @q, @>
|
using MacroTools: @q, @>
|
||||||
|
import ..Flux: Param, param, state
|
||||||
|
|
||||||
graph(m) = nothing
|
graph(m) = nothing
|
||||||
|
|
||||||
@ -18,11 +19,11 @@ function makegraph(graph, args, params = [])
|
|||||||
end
|
end
|
||||||
graph = map(graph) do x
|
graph = map(graph) do x
|
||||||
x isa Offset ?
|
x isa Offset ?
|
||||||
:(Flux.Offset($(Expr(:quote, x.name)), $(x.n),
|
:(Flux.Compiler.Offset($(Expr(:quote, x.name)), $(x.n),
|
||||||
$(x.name in params ? :(self.$(x.name)) : x.name))) :
|
$(x.name in params ? :(self.$(x.name)) : x.name))) :
|
||||||
x
|
x
|
||||||
end
|
end
|
||||||
vertex(:(Flux.Frame(self)), graph)
|
vertex(:($DataFlow.Frame(self)), graph)
|
||||||
end
|
end
|
||||||
|
|
||||||
function build_type(T, params)
|
function build_type(T, params)
|
||||||
@ -68,7 +69,7 @@ function process_type(ex)
|
|||||||
quote
|
quote
|
||||||
$(build_type(T, params))
|
$(build_type(T, params))
|
||||||
$(esc(:((self::$T)($(args...)) = $(build_forward(body, args)))))
|
$(esc(:((self::$T)($(args...)) = $(build_forward(body, args)))))
|
||||||
$(esc(:(Flux.graph(self::$T)))) = $(DataFlow.constructor(map(esc, makegraph(body, args, params))))
|
$(esc(:(Flux.Compiler.graph(self::$T)))) = $(DataFlow.constructor(map(esc, makegraph(body, args, params))))
|
||||||
nothing
|
nothing
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
@ -1,3 +1,5 @@
|
|||||||
|
using ..Flux: stack, unstack
|
||||||
|
|
||||||
# Stateful Models
|
# Stateful Models
|
||||||
|
|
||||||
mutable struct Stateful
|
mutable struct Stateful
|
||||||
|
@ -8,7 +8,7 @@ end
|
|||||||
|
|
||||||
(s::Chain)(x) = foldl((x, m) -> m(x), x, s.layers)
|
(s::Chain)(x) = foldl((x, m) -> m(x), x, s.layers)
|
||||||
|
|
||||||
graph(s::Chain) =
|
Compiler.graph(s::Chain) =
|
||||||
foldl((v, m) -> vertex(m, v), constant(inputnode(1)), s.layers)
|
foldl((v, m) -> vertex(m, v), constant(inputnode(1)), s.layers)
|
||||||
|
|
||||||
Base.getindex(c::Chain, i::AbstractArray) = Chain(c.layers[i]...)
|
Base.getindex(c::Chain, i::AbstractArray) = Chain(c.layers[i]...)
|
||||||
|
21
src/utils.jl
21
src/utils.jl
@ -15,27 +15,6 @@ convertel{T}(::Type{T}, xs::AbstractArray{T}) = xs
|
|||||||
|
|
||||||
a ∘ b = a .* b
|
a ∘ b = a .* b
|
||||||
|
|
||||||
broadcastto(xs::AbstractArray, shape) = xs .* ones(shape)
|
|
||||||
|
|
||||||
# Tuples
|
|
||||||
|
|
||||||
mapt(f, x) = f(x)
|
|
||||||
mapt(f, xs::Tuple) = map(x -> mapt(f, x), xs)
|
|
||||||
|
|
||||||
function collectt(xs)
|
|
||||||
ys = []
|
|
||||||
mapt(x -> push!(ys, x), xs)
|
|
||||||
return ys
|
|
||||||
end
|
|
||||||
|
|
||||||
function shapecheckt(xs::Tuple, ys::Tuple)
|
|
||||||
length(xs) == length(ys) || error("Expected tuple length $(length(xs)), got $ys")
|
|
||||||
shapecheckt.(xs, ys)
|
|
||||||
end
|
|
||||||
|
|
||||||
shapecheckt(xs::Tuple, ys) = error("Expected tuple, got $ys")
|
|
||||||
shapecheckt(xs, ys) = nothing
|
|
||||||
|
|
||||||
# Other
|
# Other
|
||||||
|
|
||||||
function accuracy(m, data)
|
function accuracy(m, data)
|
||||||
|
@ -19,7 +19,7 @@ end
|
|||||||
function test_recurrence(bk)
|
function test_recurrence(bk)
|
||||||
@testset "Recurrence" begin
|
@testset "Recurrence" begin
|
||||||
seq = unsqueeze(stack(rand(10) for i = 1:3))
|
seq = unsqueeze(stack(rand(10) for i = 1:3))
|
||||||
r = unroll(Recurrent(10, 5), 3)
|
r = Flux.Compiler.unroll(Recurrent(10, 5), 3)
|
||||||
rm = bk(r)
|
rm = bk(r)
|
||||||
@test r(seq) ≈ rm(seq)
|
@test r(seq) ≈ rm(seq)
|
||||||
end
|
end
|
||||||
|
@ -1,3 +1,5 @@
|
|||||||
|
using Flux: Affine
|
||||||
|
|
||||||
syntax(v::Vertex) = prettify(DataFlow.syntax(v))
|
syntax(v::Vertex) = prettify(DataFlow.syntax(v))
|
||||||
syntax(x) = syntax(graph(x))
|
syntax(x) = syntax(graph(x))
|
||||||
|
|
||||||
@ -21,12 +23,12 @@ test_anon(identity)
|
|||||||
let a1 = Affine(10, 20), a2 = Affine(20, 15)
|
let a1 = Affine(10, 20), a2 = Affine(20, 15)
|
||||||
tlp = TLP(a1, a2)
|
tlp = TLP(a1, a2)
|
||||||
@test tlp(xs) ≈ softmax(a2(σ(a1(xs))))
|
@test tlp(xs) ≈ softmax(a2(σ(a1(xs))))
|
||||||
@test Flux.interpmodel(tlp, xs) ≈ softmax(a2(σ(a1(xs))))
|
@test Flux.Compiler.interpmodel(tlp, xs) ≈ softmax(a2(σ(a1(xs))))
|
||||||
end
|
end
|
||||||
|
|
||||||
let tlp = TLP(Affine(10, 21), Affine(20, 15))
|
let tlp = TLP(Affine(10, 21), Affine(20, 15))
|
||||||
e = try
|
e = try
|
||||||
Flux.interpmodel(tlp, rand(1, 10))
|
Flux.Compiler.interpmodel(tlp, rand(1, 10))
|
||||||
catch e
|
catch e
|
||||||
e
|
e
|
||||||
end
|
end
|
||||||
|
@ -1,3 +1,5 @@
|
|||||||
|
using Flux: Recurrent
|
||||||
|
|
||||||
function apply(model, xs, state)
|
function apply(model, xs, state)
|
||||||
ys = similar(xs, 0)
|
ys = similar(xs, 0)
|
||||||
for x in xs
|
for x in xs
|
||||||
@ -10,8 +12,8 @@ end
|
|||||||
@testset "RNN unrolling" begin
|
@testset "RNN unrolling" begin
|
||||||
r = Recurrent(10, 5)
|
r = Recurrent(10, 5)
|
||||||
xs = [rand(1, 10) for _ = 1:3]
|
xs = [rand(1, 10) for _ = 1:3]
|
||||||
_, ys = apply(unroll1(r).model, xs, (r.y.x,))
|
_, ys = apply(Flux.Compiler.unroll1(r).model, xs, (r.y.x,))
|
||||||
@test ys[1] == tanh(xs[1] * r.Wxy.x .+ r.y.x * r.Wyy.x .+ r.by.x)
|
@test ys[1] == tanh(xs[1] * r.Wxy.x .+ r.y.x * r.Wyy.x .+ r.by.x)
|
||||||
ru = unroll(r, 3)
|
ru = Flux.Compiler.unroll(r, 3)
|
||||||
ru(unsqueeze(stack(squeeze.(xs))))[1] == squeeze.(ys)
|
ru(unsqueeze(stack(squeeze.(xs))))[1] == squeeze.(ys)
|
||||||
end
|
end
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
using Flux, DataFlow, MacroTools, Base.Test
|
using Flux, DataFlow, MacroTools, Base.Test
|
||||||
using Flux: graph, Param, squeeze, unsqueeze, stack, update!, flatten
|
using Flux: Param, param, squeeze, unsqueeze, stack, update!, flatten
|
||||||
|
using Flux.Compiler: @net
|
||||||
using DataFlow: Line, Frame
|
using DataFlow: Line, Frame
|
||||||
|
|
||||||
@testset "Flux" begin
|
@testset "Flux" begin
|
||||||
|
Loading…
Reference in New Issue
Block a user