remove compiler
This commit is contained in:
parent
96d1c55263
commit
2ec8401d2c
@ -22,8 +22,6 @@ using .Optimise
|
||||
include("utils.jl")
|
||||
include("onehot.jl")
|
||||
|
||||
include("compiler/Compiler.jl")
|
||||
|
||||
include("layers/stateless.jl")
|
||||
include("layers/basic.jl")
|
||||
include("layers/recurrent.jl")
|
||||
|
@ -1,14 +0,0 @@
|
||||
module Compiler
|
||||
|
||||
using MacroTools, DataFlow, DataFlow.Interpreter
|
||||
|
||||
using DataFlow: graphm, syntax, prewalk!, postwalk!, prewalk, postwalk,
|
||||
iscyclic, Constant, constant, isconstant, group, Split,
|
||||
detuple, value, inputs, thread!, value, inputs, inputnode,
|
||||
spliceinputs, bumpinputs, Line, Frame, applylines, graphinputs
|
||||
|
||||
include("code.jl")
|
||||
include("interp.jl")
|
||||
include("loops.jl")
|
||||
|
||||
end
|
@ -1,77 +0,0 @@
|
||||
import DataFlow: cse
|
||||
using MacroTools: @q, @>
|
||||
|
||||
graph(m) = nothing
|
||||
|
||||
function graphdef(ex, params = [])
|
||||
@capture(shortdef(ex), (args__,) -> body_)
|
||||
body = @> body MacroTools.flatten liftloops graphm DataFlow.il
|
||||
body = map(x -> x in params ? :(self.$x) : x, body)
|
||||
return args, body
|
||||
end
|
||||
|
||||
function makegraph(graph, args, params = [])
|
||||
graph = prewalk(graph) do v
|
||||
isconstant(v) && (i = findfirst(args, value(v[1]))) ≠ 0 ?
|
||||
inputnode(i) :
|
||||
v
|
||||
end
|
||||
graph = map(graph) do x
|
||||
x isa Offset ?
|
||||
:(Flux.Compiler.Offset($(Expr(:quote, x.name)), $(x.n),
|
||||
$(x.name in params ? :(self.$(x.name)) : x.name))) :
|
||||
x
|
||||
end
|
||||
vertex(:($DataFlow.Frame(self)), graph)
|
||||
end
|
||||
|
||||
function build_type(T, params)
|
||||
@esc T
|
||||
:(type $T
|
||||
$(params...)
|
||||
end)
|
||||
end
|
||||
|
||||
function build_forward(body, args)
|
||||
iscyclic(body) && return :(error("Can't run forward pass on a cyclic graph"))
|
||||
applylines(syntax(cse(body)))
|
||||
end
|
||||
|
||||
import Lazy: groupby
|
||||
|
||||
# TODO: type hints for parameters
|
||||
|
||||
function process_type(ex)
|
||||
@capture(ex, type T_ fs__ end)
|
||||
@destruct [params = false || [],
|
||||
funcs = true || []] = groupby(x->isexpr(x, :->, :function), fs)
|
||||
@assert length(funcs) == 1
|
||||
pnames = namify.(params)
|
||||
args, body = graphdef(funcs[1], pnames)
|
||||
self = esc(:self)
|
||||
quote
|
||||
$(build_type(T, params))
|
||||
$(esc(:((self::$T)($(args...)) = $(build_forward(body, args)))))
|
||||
$(esc(:(Flux.Compiler.graph(self::$T)))) = $(DataFlow.constructor(map(esc, makegraph(body, args, params))))
|
||||
nothing
|
||||
end
|
||||
end
|
||||
|
||||
function process_anon(ex)
|
||||
args, body = graphdef(ex)
|
||||
:(Capacitor($(DataFlow.constructor(map(esc, makegraph(body, args)[1])))))
|
||||
end
|
||||
|
||||
function process_def(ex)
|
||||
# TODO: make a singleton net type
|
||||
@capture(ex, f_(xs__) = body_)
|
||||
:($(esc(f)) = @net $(esc(:(($(xs...),) -> $body))); nothing)
|
||||
end
|
||||
|
||||
macro net(ex)
|
||||
ex = shortdef(ex)
|
||||
isexpr(ex, :type) ? process_type(ex) :
|
||||
@capture(ex, (__,) -> _) ? process_anon(ex) :
|
||||
@capture(ex, _(__) = _) ? process_def(ex) :
|
||||
error("Unsupported model expression $ex")
|
||||
end
|
@ -1,39 +0,0 @@
|
||||
function astuple(xs::Vertex)
|
||||
isconstant(xs) && value(xs[1]) isa Tuple ? value(xs[1]) :
|
||||
xs isa Vertex && value(xs) == tuple ? inputs(xs) :
|
||||
nothing
|
||||
end
|
||||
|
||||
astuple(xs::Tuple) = xs
|
||||
|
||||
astuple(xs) = nothing
|
||||
|
||||
function astuples(xs)
|
||||
xs = [astuple(x) for x in xs]
|
||||
all(x->!(x==nothing), xs) ? xs : nothing
|
||||
end
|
||||
|
||||
function interp(ctx, f, xs...)
|
||||
g = graph(f)
|
||||
g ≠ nothing && iscyclic(g) && error("Can't interpret cyclic graph")
|
||||
@icatch(ctx, g ≠ nothing ?
|
||||
interpret(ctx, g, xs...) :
|
||||
f(xs...))
|
||||
end
|
||||
|
||||
function interpmodel(m, args...)
|
||||
ctx = Context(mux(iconst, iline, ilambda, iargs, ituple, interp))
|
||||
@ithrow interp(ctx, m, args...)
|
||||
end
|
||||
|
||||
# Anonymous models
|
||||
|
||||
struct Capacitor
|
||||
graph::IVertex{Any}
|
||||
end
|
||||
|
||||
(m::Capacitor)(xs...) = interpmodel(m, xs...)
|
||||
|
||||
graph(cap::Capacitor) = cap.graph
|
||||
|
||||
Base.show(io::IO, ::Capacitor) = print(io, "Capacitor(...)")
|
@ -1,191 +0,0 @@
|
||||
using ..Flux: stack, unstack, squeeze, unsqueeze
|
||||
|
||||
# Stateful Models
|
||||
|
||||
mutable struct Stateful
|
||||
model
|
||||
states::Vector{Any}
|
||||
istate::Vector{Any}
|
||||
ostate::Vector{Any}
|
||||
end
|
||||
|
||||
Stateful(model, ss) = Stateful(model, ss, ss, ss)
|
||||
|
||||
function Base.show(io::IO, m::Stateful)
|
||||
print(io, "Stateful(")
|
||||
show(io, m.model)
|
||||
print(io, ")")
|
||||
end
|
||||
|
||||
function (m::Stateful)(xs...)
|
||||
m.istate = m.ostate
|
||||
state, y = m.model((m.istate...,), xs...)
|
||||
m.ostate = collect(state)
|
||||
return y
|
||||
end
|
||||
|
||||
# Seq Models
|
||||
|
||||
struct SeqModel
|
||||
model
|
||||
steps::Int
|
||||
end
|
||||
|
||||
seqtuple(x, n) = x
|
||||
seqtuple(xs::Tuple, n) = seqtuple.(xs, n)
|
||||
|
||||
seqtuple(xs::AbstractArray, n) =
|
||||
ndims(xs) < 3 ? xs :
|
||||
n ≠ 0 && size(xs, 2) ≠ n ? error("Expecting sequence length $n, got $(size(xs, 2))") :
|
||||
(unstack(xs, 2)...)
|
||||
|
||||
reseq(x) = x
|
||||
reseq(x::Tuple{}) = ()
|
||||
reseq(xs::Tuple) = all(isa.(xs, AbstractArray) .& (ndims.(xs) .≥ 2)) ? stack(xs, 2) : reseq.(xs)
|
||||
|
||||
function (m::SeqModel)(xs...)
|
||||
xs = seqtuple(xs, m.steps)
|
||||
reseq(m.model(xs...))
|
||||
end
|
||||
|
||||
graph(m::SeqModel) = graph(m.model)
|
||||
|
||||
# Recurrent Graphs
|
||||
|
||||
struct Offset
|
||||
name::Symbol
|
||||
n::Int
|
||||
default::Nullable{Any}
|
||||
end
|
||||
|
||||
Offset(name, n) = Offset(name, n, nothing)
|
||||
|
||||
Base.:-(o::Offset) = Offset(o.name, -o.n, o.default)
|
||||
|
||||
function liftloops(ex)
|
||||
ex = DataFlow.normedges(ex)
|
||||
decls = Dict()
|
||||
ex = MacroTools.postwalk(ex) do ex
|
||||
@capture(ex, x_{n_}) || return ex
|
||||
haskey(decls, (x,n)) && return namify(decls[(x,n)])
|
||||
@gensym edge
|
||||
decls[(x,n)] = :($edge = $(Offset(x,n))($x))
|
||||
edge
|
||||
end
|
||||
prepend!(ex.args, collect(values(decls)))
|
||||
ex
|
||||
end
|
||||
|
||||
function hasloops(model)
|
||||
g = graph(model)
|
||||
g == nothing && return false
|
||||
iscyclic(g) && return true
|
||||
result = false
|
||||
map(m -> hasloops(m) && (result = true), g)
|
||||
return result
|
||||
end
|
||||
|
||||
function atomise(model)
|
||||
postwalk(graph(model)) do v
|
||||
hasloops(value(v)) || return v
|
||||
spliceinputs(atomise(value(v)), inputs(v)...)
|
||||
end
|
||||
end
|
||||
|
||||
function collect_state(v::IVertex)
|
||||
state = typeof(v)[]
|
||||
offset = Int[]
|
||||
default = []
|
||||
prewalk!(v) do v
|
||||
value(v) isa Offset || return v
|
||||
if (i = findfirst(state, v[1])) == 0
|
||||
push!(state, v[1])
|
||||
push!(offset, max(0, -value(v).n))
|
||||
push!(default, get(value(v).default))
|
||||
else
|
||||
offset[i] = max(offset[i], -value(v).n)
|
||||
end
|
||||
v
|
||||
end
|
||||
return state, offset, default
|
||||
end
|
||||
|
||||
hiddeninput(n, t) = vertex(Split(t), inputnode(n))
|
||||
|
||||
# TODO: nicer way to do this.
|
||||
create_steps(v::IVertex, n) = [bumpinputs(spliceinputs(v, [hiddeninput(n, t) for n = 1:graphinputs(v)]...)) for t = 1:n]
|
||||
|
||||
function getvar(n, step, steps, offset, default)
|
||||
if step < 1
|
||||
hiddeninput(1, sum(offset[1:n-1]) + 1 - step)
|
||||
elseif step ∉ 1:length(steps)
|
||||
constant(default[n])
|
||||
else
|
||||
steps[step][1,n]
|
||||
end
|
||||
end
|
||||
|
||||
function stateout(steps, offset, default)
|
||||
outs = []
|
||||
defaults = []
|
||||
for i = 1:length(offset), j = 1:offset[i]
|
||||
push!(outs, getvar(i, length(steps)-j+1, steps, offset, default))
|
||||
push!(defaults, default[i])
|
||||
end
|
||||
group(outs...), defaults
|
||||
end
|
||||
|
||||
# Input: (hidden1, hidden2, ...), (x1, x2, ...)
|
||||
# Output: (hidden1, hidden2, ...), (y1, y2, ...)
|
||||
# TODO: make sure there's a reasonable order for hidden states
|
||||
|
||||
function unrollgraph(v::IVertex, n)
|
||||
state, offset, default = collect_state(v)
|
||||
v = group(group(state...), v)
|
||||
steps = create_steps(v, n)
|
||||
for i = 1:n
|
||||
vars = inputs(steps[i][1])
|
||||
postwalk!(steps[i]) do v
|
||||
value(v) isa Offset || return v
|
||||
varid = findfirst(vars,v[1])
|
||||
getvar(varid, value(v).n + i, steps, offset, default)
|
||||
end
|
||||
end
|
||||
out = group(map(x->x[2], steps)...)
|
||||
state, defaults = stateout(steps, offset, default)
|
||||
group(state,out), defaults
|
||||
end
|
||||
|
||||
unrollgraph(m, n; kws...) = unrollgraph(atomise(m), n; kws...)
|
||||
|
||||
function unroll(model, n)
|
||||
graph, state = unrollgraph(model, n)
|
||||
SeqModel(Stateful(Capacitor(graph), state), n)
|
||||
end
|
||||
|
||||
function stateless(s::Stateful)
|
||||
v = graph(s.model)
|
||||
v = spliceinputs(v, group(constant.(s.states)...),
|
||||
[inputnode(i) for i = 1:graphinputs(v)-1]...)
|
||||
Capacitor(v[2])
|
||||
end
|
||||
|
||||
stateless(s::SeqModel) = SeqModel(stateless(s.model), s.steps)
|
||||
|
||||
function unseqin(v::IVertex)
|
||||
prewalk(v) do v
|
||||
# TODO: inputidx function
|
||||
isa(value(v), Split) && DataFlow.isinput(v[1]) && value(v[1]).n > 1 ? v[1] : v
|
||||
end
|
||||
end
|
||||
|
||||
unseqout(v::IVertex) = group(v[1], v[2][1])
|
||||
|
||||
unseq(graph) = unseqout(unseqin(graph))
|
||||
|
||||
function unroll1(model)
|
||||
graph, state = unrollgraph(model, 1)
|
||||
Stateful(Capacitor(unseq(graph)), state)
|
||||
end
|
||||
|
||||
flip(model) = Capacitor(map(x -> x isa Offset ? -x : x, atomise(model)))
|
@ -26,9 +26,6 @@ Optimise.children(c::Chain) = c.layers
|
||||
|
||||
(s::Chain)(x) = foldl((x, m) -> m(x), x, s.layers)
|
||||
|
||||
Compiler.graph(s::Chain) =
|
||||
foldl((v, m) -> vertex(m, v), constant(inputnode(1)), s.layers)
|
||||
|
||||
Base.getindex(c::Chain, i::AbstractArray) = Chain(c.layers[i]...)
|
||||
|
||||
function Base.show(io::IO, c::Chain)
|
||||
|
@ -1,86 +0,0 @@
|
||||
using DataFlow, MacroTools
|
||||
using Flux: stack, unsqueeze
|
||||
using Flux.Compiler: @net, graph
|
||||
using DataFlow: Line, Frame
|
||||
|
||||
@net type Affine
|
||||
W
|
||||
b
|
||||
x -> x*W .+ b
|
||||
end
|
||||
|
||||
Affine(in::Integer, out::Integer; init = Flux.initn) =
|
||||
Affine(init(in, out), init(1, out))
|
||||
|
||||
@net type TLP
|
||||
first
|
||||
second
|
||||
function (x)
|
||||
l1 = σ.(first(x))
|
||||
l2 = softmax(second(l1))
|
||||
end
|
||||
end
|
||||
|
||||
@net type Recurrent
|
||||
Wxy; Wyy; by
|
||||
y
|
||||
function (x)
|
||||
y = tanh.( x * Wxy .+ y{-1} * Wyy .+ by )
|
||||
end
|
||||
end
|
||||
|
||||
Recurrent(in, out; init = Flux.initn) =
|
||||
Recurrent(init((in, out)), init((out, out)), init(1, out), init(1, out))
|
||||
|
||||
syntax(v::Vertex) = prettify(DataFlow.syntax(v))
|
||||
syntax(x) = syntax(graph(x))
|
||||
|
||||
@testset "Compiler" begin
|
||||
|
||||
xs = randn(1, 10)
|
||||
d = Affine(10, 20)
|
||||
|
||||
@test d(xs) ≈ (xs*d.W + d.b)
|
||||
|
||||
d1 = @net x -> x * d.W + d.b
|
||||
|
||||
let
|
||||
@capture(syntax(d), _Frame(_Line((+).(x_[1] * W_, b_))))
|
||||
@test isa(x, DataFlow.Input) && W isa Array && b isa Array
|
||||
end
|
||||
|
||||
let a1 = Affine(10, 20), a2 = Affine(20, 15)
|
||||
tlp = TLP(a1, a2)
|
||||
@test tlp(xs) ≈ softmax(a2(σ.(a1(xs))))
|
||||
@test Flux.Compiler.interpmodel(tlp, xs) ≈ softmax(a2(σ.(a1(xs))))
|
||||
end
|
||||
|
||||
let tlp = TLP(Affine(10, 21), Affine(20, 15))
|
||||
e = try
|
||||
Flux.Compiler.interpmodel(tlp, rand(1, 10))
|
||||
catch e
|
||||
e
|
||||
end
|
||||
@test e.trace[end].func == :TLP
|
||||
@test e.trace[end-1].func == Symbol("Affine")
|
||||
end
|
||||
|
||||
function apply(model, xs, state)
|
||||
ys = similar(xs, 0)
|
||||
for x in xs
|
||||
state, y = model(state, x)
|
||||
push!(ys, y)
|
||||
end
|
||||
state, ys
|
||||
end
|
||||
|
||||
@testset "RNN unrolling" begin
|
||||
r = Recurrent(10, 5)
|
||||
xs = [rand(1, 10) for _ = 1:3]
|
||||
_, ys = apply(Flux.Compiler.unroll1(r).model, xs, (r.y,))
|
||||
@test ys[1] == tanh.(xs[1] * r.Wxy .+ r.y * r.Wyy .+ r.by)
|
||||
ru = Flux.Compiler.unroll(r, 3)
|
||||
ru(unsqueeze(stack(squeeze.(xs, 1), 1), 1))[1] == squeeze.(ys, 1)
|
||||
end
|
||||
|
||||
end
|
@ -2,7 +2,6 @@ using Flux, Base.Test
|
||||
|
||||
@testset "Flux" begin
|
||||
|
||||
include("compiler.jl")
|
||||
include("utils.jl")
|
||||
include("tracker.jl")
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user