diff --git a/src/Flux.jl b/src/Flux.jl
index c42188f5..a91846da 100644
--- a/src/Flux.jl
+++ b/src/Flux.jl
@@ -25,7 +25,6 @@ include("params.jl")
 include("compiler/code.jl")
 include("compiler/loops.jl")
 include("compiler/interp.jl")
-include("compiler/shape.jl")
 
 include("layers/control.jl")
 include("layers/affine.jl")
diff --git a/src/compiler/shape.jl b/src/compiler/shape.jl
deleted file mode 100644
index 1b7fb41f..00000000
--- a/src/compiler/shape.jl
+++ /dev/null
@@ -1,74 +0,0 @@
-Dims{N} = NTuple{N,Int}
-
-struct Hint
-  typ
-end
-
-DataFlow.tocall(h::Hint, x) = :($x::$(h.typ))
-
-arghint(p::Param) = arghint(state(p))
-arghint(xs::AbstractArray) = vertex(Hint(size(xs)), constant(:_))
-arghint(x) = constant(x)
-
-function gethint(v::IVertex)
-  while value(v) isa Union{Line,Frame} v = v[1] end
-  value(v) isa Hint && return value(v).typ
-  return
-end
-
-ihint(f, ctx::Context, h::Hint, x) = vertex(h, x)
-ihint(f, args...) = f(args...)
-
-function hintify(ctx, f, xs...)
-  xs = arghint.(xs)
-  sh = infer(f, map(gethint, xs)...)
-  sh ≠ nothing ? vertex(Hint(sh), vertex(f, xs...)) :
-  !any(x->x==nothing, xs) && graph(f) ≠ nothing ? interpret(Context(interpshape), graph(f), xs...) :
-  vertex(f, xs...)
-end
-
-interpshape = mux(ilinev, iconst, ihint, iargs, hintify)
-
-function shapesv(f, args...)
-  (g = graph(f)) == nothing && return
-  ins = [vertex(Hint(d), inputnode(i)) for (i,d) in enumerate(args)]
-  interpv(Context(interpshape), detuple(spliceinputs(g, ins...)))
-end
-
-shapes(args...) = shapesv(args...) |> syntax |> applylines |> (x->prettify(x, lines=true))
-
-# Inference primitives
-
-infer(f, args...) = graph(f) == nothing ? nothing : gethint(shapesv(f, args...))
-
-infer(::typeof(tuple), xs...) = (xs...,)
-infer(s::Split, xs::Tuple) = 1 ≤ s.n ≤ length(xs) ? xs[s.n] : nothing
-infer(::typeof(identity), x) = x
-
-function infer(::typeof(*), a::Dims{2}, b::Dims{2})
-  a[2] == b[1] || return nothing
-  (a[1], b[2])
-end
-
-infer(::typeof(broadcast), f, xs::Dims...) = Base.Broadcast.broadcast_shape(xs...)
-# Old broadcast versions
-infer(::typeof(.+), xs::Dims...) = Base.Broadcast.broadcast_shape(xs...)
-
-# Shapes macro
-
-macro shapes(ex)
-  @capture(ex, f_(args__)) || error("@shapes f(args...)")
-  :(shapes($(esc(f)), mapt(size, ($(map(esc, args)...),))...))
-end
-
-# Shim for kicking off shape inference
-
-struct Input{N}
-  dims::Dims{N}
-end
-
-Input(i...) = Input((i...,))
-
-(::Input)(x) = x
-
-inferchain(f::Input, xs) = (-1, f.dims...)
diff --git a/src/layers/activation.jl b/src/layers/activation.jl
index d22ce7d7..7d0d27ad 100644
--- a/src/layers/activation.jl
+++ b/src/layers/activation.jl
@@ -9,10 +9,3 @@ back!(::typeof(relu), Δ, x) = Δ .* (x .> 0)
 softmax(xs) = exp.(xs) ./ sum(exp.(xs), 2)
 
 flatten(xs) = reshape(xs, size(xs, 1), :)
-
-infer(::typeof(softmax), x) = x
-infer(::typeof(tanh), x) = x
-infer(::typeof(relu), x) = x
-infer(::typeof(σ), x) = x
-
-infer(::typeof(flatten), x::Dims) = (x[1], prod(x[2:end])...)
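
For reference, here is a minimal standalone sketch of the inference rules the hunks above delete: shapes travel as plain integer tuples, and `nothing` marks an unknown or incompatible shape. `ShapeDims` and `infer_shape` are hypothetical stand-ins for the removed `Dims` alias and `infer` generic, written against current Julia rather than the 0.6-era code being deleted.

# Hedged sketch, not part of the patch; illustrative names only.
const ShapeDims{N} = NTuple{N,Int}

# Matrix multiply: (m, k) * (k, n) -> (m, n); a mismatched inner
# dimension yields `nothing`, the "shape unknown" sentinel.
function infer_shape(::typeof(*), a::ShapeDims{2}, b::ShapeDims{2})
  a[2] == b[1] || return nothing
  (a[1], b[2])
end

# Elementwise activations preserve shape, mirroring the deleted
# `infer(::typeof(tanh), x) = x` style methods in layers/activation.jl.
infer_shape(::typeof(tanh), x::ShapeDims) = x

infer_shape(*, (1, 10), (10, 20))  # (1, 20)
infer_shape(*, (1, 10), (20, 15))  # nothing: 10 ≠ 20
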
diff --git a/src/layers/affine.jl b/src/layers/affine.jl
index ca79c004..c6586cf8 100644
--- a/src/layers/affine.jl
+++ b/src/layers/affine.jl
@@ -7,9 +7,6 @@ end
 Affine(in::Integer, out::Integer; init = initn) =
   Affine(init(in, out), init(1, out))
 
-inferred(::Type{Affine}, in::Tuple{Dims{2}}, out::Integer) =
-  Affine(in[1][2], out)
-
 function back!(m::Affine, Δ, x)
   W, b = m.W, m.b
   W.Δx[:] = x' * Δ
diff --git a/src/layers/control.jl b/src/layers/control.jl
index d0c5e61b..29700f38 100644
--- a/src/layers/control.jl
+++ b/src/layers/control.jl
@@ -24,24 +24,3 @@ graph(s::Chain) =
   foldl((v, m) -> vertex(m, v), constant(inputnode(1)), s.layers)
 
 Base.getindex(c::Chain, i::AbstractArray) = Chain(c.layers[i]...)
-
-# Chain Macros
-
-inferred(f, in, args...; kws...) = f(args...; kws...)
-
-# `inferchain` allows for overriding inference behaviour for convenience.
-# For example, `infer(Affine(10, 20), nothing)` would normally return a shape
-# error, but for the interface we just ignore any errors and return (1, 20).
-inferchain(f, xs...) = infer(f, xs...)
-
-macro Chain(x, xs...)
-  inferconstructor(x) =
-    @capture(x, f_(xs__)) ? :(inferred($(esc(f)), (shape,), $(esc.(xs)...))) : esc(x)
-  @q let
-    shape = nothing
-    c = Chain($(esc(x)))
-    $([:(shape = inferchain(c.layers[end], shape);
-         push!(c, $x)) for x in inferconstructor.(xs)]...)
-    c
-  end
-end
diff --git a/test/basic.jl b/test/basic.jl
index 93d77275..d2b28d3c 100644
--- a/test/basic.jl
+++ b/test/basic.jl
@@ -10,8 +10,6 @@ d = Affine(10, 20)
 
 d1 = @net x -> x * d.W + d.b
 
-@test Flux.infer(d, (1, 10)) == (1,20)
-
 # Skip this before new DataFlow is released.
 # let
 #   @test @capture(syntax(d), _Frame(_Line((+).(x_[1] * W_, b_))))
@@ -24,7 +22,6 @@ let a1 = Affine(10, 20), a2 = Affine(20, 15)
   tlp = TLP(a1, a2)
   @test tlp(xs) ≈ softmax(a2(σ(a1(xs))))
   @test Flux.interpmodel(tlp, xs) ≈ softmax(a2(σ(a1(xs))))
-  @test Flux.infer(tlp, (1, 10)) == (1,15)
 end
 
 let tlp = TLP(Affine(10, 21), Affine(20, 15))
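
For context on what the removed `@Chain` machinery bought: via `inferred`, a constructor inside `@Chain` could omit its input size and take it from the previous layer's inferred output shape. Below is a hedged, function-based approximation of that convenience in current Julia; `Dense2` and `chain_inferred` are hypothetical names for illustration, not Flux API.

# Hedged sketch, not part of the patch; illustrative names only.
struct Dense2           # toy stand-in for Flux's Affine
  W::Matrix{Float64}
  b::Matrix{Float64}
end
Dense2(in::Integer, out::Integer) = Dense2(randn(in, out), randn(1, out))
(d::Dense2)(x) = x * d.W .+ d.b

# Thread the running width through the constructors, much as the removed
# macro threaded its `shape` variable through `inferchain`/`inferred`.
function chain_inferred(in::Integer, outs::Integer...)
  layers = Dense2[]
  for out in outs
    push!(layers, Dense2(in, out))  # `in` comes from the previous layer
    in = out
  end
  layers
end

m = chain_inferred(10, 20, 15)              # Dense2(10, 20), Dense2(20, 15)
y = foldl((v, l) -> l(v), m; init = randn(1, 10))
size(y)                                     # (1, 15)

After this patch, every constructor spells out both sizes explicitly, as in the test file's `TLP(Affine(10, 20), Affine(20, 15))`.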