remove inference, for now

This commit is contained in:
Mike J Innes 2017-08-18 01:15:30 +01:00
parent e79a1657d4
commit 052cc52ada
6 changed files with 0 additions and 109 deletions

View File

@ -25,7 +25,6 @@ include("params.jl")
include("compiler/code.jl")
include("compiler/loops.jl")
include("compiler/interp.jl")
include("compiler/shape.jl")
include("layers/control.jl")
include("layers/affine.jl")

View File

@ -1,74 +0,0 @@
Dims{N} = NTuple{N,Int}
struct Hint
typ
end
DataFlow.tocall(h::Hint, x) = :($x::$(h.typ))
arghint(p::Param) = arghint(state(p))
arghint(xs::AbstractArray) = vertex(Hint(size(xs)), constant(:_))
arghint(x) = constant(x)
function gethint(v::IVertex)
while value(v) isa Union{Line,Frame} v = v[1] end
value(v) isa Hint && return value(v).typ
return
end
ihint(f, ctx::Context, h::Hint, x) = vertex(h, x)
ihint(f, args...) = f(args...)
function hintify(ctx, f, xs...)
xs = arghint.(xs)
sh = infer(f, map(gethint, xs)...)
sh nothing ? vertex(Hint(sh), vertex(f, xs...)) :
!any(x->x==nothing, xs) && graph(f) nothing ? interpret(Context(interpshape), graph(f), xs...) :
vertex(f, xs...)
end
interpshape = mux(ilinev, iconst, ihint, iargs, hintify)
function shapesv(f, args...)
(g = graph(f)) == nothing && return
ins = [vertex(Hint(d), inputnode(i)) for (i,d) in enumerate(args)]
interpv(Context(interpshape), detuple(spliceinputs(g, ins...)))
end
shapes(args...) = shapesv(args...) |> syntax |> applylines |> (x->prettify(x, lines=true))
# Inference primitives
infer(f, args...) = graph(f) == nothing ? nothing : gethint(shapesv(f, args...))
infer(::typeof(tuple), xs...) = (xs...,)
infer(s::Split, xs::Tuple) = 1 s.n length(xs) ? xs[s.n] : nothing
infer(::typeof(identity), x) = x
function infer(::typeof(*), a::Dims{2}, b::Dims{2})
a[2] == b[1] || return nothing
(a[1], b[2])
end
infer(::typeof(broadcast), f, xs::Dims...) = Base.Broadcast.broadcast_shape(xs...)
# Old broadcast versions
infer(::typeof(.+), xs::Dims...) = Base.Broadcast.broadcast_shape(xs...)
# Shapes macro
macro shapes(ex)
@capture(ex, f_(args__)) || error("@shapes f(args...)")
:(shapes($(esc(f)), mapt(size, ($(map(esc, args)...),))...))
end
# Shim for kicking off shape inference
struct Input{N}
dims::Dims{N}
end
Input(i...) = Input((i...,))
(::Input)(x) = x
inferchain(f::Input, xs) = (-1, f.dims...)

View File

@ -9,10 +9,3 @@ back!(::typeof(relu), Δ, x) = Δ .* (x .> 0)
softmax(xs) = exp.(xs) ./ sum(exp.(xs), 2)
flatten(xs) = reshape(xs, size(xs, 1), :)
infer(::typeof(softmax), x) = x
infer(::typeof(tanh), x) = x
infer(::typeof(relu), x) = x
infer(::typeof(σ), x) = x
infer(::typeof(flatten), x::Dims) = (x[1], prod(x[2:end])...)

View File

@ -7,9 +7,6 @@ end
Affine(in::Integer, out::Integer; init = initn) =
Affine(init(in, out), init(1, out))
inferred(::Type{Affine}, in::Tuple{Dims{2}}, out::Integer) =
Affine(in[1][2], out)
function back!(m::Affine, Δ, x)
W, b = m.W, m.b
W.Δx[:] = x' * Δ

View File

@ -24,24 +24,3 @@ graph(s::Chain) =
foldl((v, m) -> vertex(m, v), constant(inputnode(1)), s.layers)
Base.getindex(c::Chain, i::AbstractArray) = Chain(c.layers[i]...)
# Chain Macros
inferred(f, in, args...; kws...) = f(args...; kws...)
# `inferchain` allows for overriding inference behaviour for convenience.
# For example, `infer(Affine(10, 20), nothing)` would normally return a shape
# error, but for the interface we just ignore any errors and return (1, 20).
inferchain(f, xs...) = infer(f, xs...)
macro Chain(x, xs...)
inferconstructor(x) =
@capture(x, f_(xs__)) ? :(inferred($(esc(f)), (shape,), $(esc.(xs)...))) : esc(x)
@q let
shape = nothing
c = Chain($(esc(x)))
$([:(shape = inferchain(c.layers[end], shape);
push!(c, $x)) for x in inferconstructor.(xs)]...)
c
end
end

View File

@ -10,8 +10,6 @@ d = Affine(10, 20)
d1 = @net x -> x * d.W + d.b
@test Flux.infer(d, (1, 10)) == (1,20)
# Skip this before new DataFlow is released.
# let
# @test @capture(syntax(d), _Frame(_Line((+).(x_[1] * W_, b_))))
@ -24,7 +22,6 @@ let a1 = Affine(10, 20), a2 = Affine(20, 15)
tlp = TLP(a1, a2)
@test tlp(xs) softmax(a2(σ(a1(xs))))
@test Flux.interpmodel(tlp, xs) softmax(a2(σ(a1(xs))))
@test Flux.infer(tlp, (1, 10)) == (1,15)
end
let tlp = TLP(Affine(10, 21), Affine(20, 15))