diff --git a/src/compiler/shape.jl b/src/compiler/shape.jl index d4b3f284..59d4c6f4 100644 --- a/src/compiler/shape.jl +++ b/src/compiler/shape.jl @@ -22,7 +22,7 @@ ihint(f, args...) = f(args...) hintify(ctx, c::Constant) = hintify(ctx, state(c.value)) hintify(ctx, xs::AbstractArray) = vertex(Hint(size(xs)), constant(:_)) -interpshape = mux(ilinev, ihint, iargs, ituple, hintify) +interpshape = mux(ilinev, ihint, iargs, hintify) function hintify(ctx, f, xs...) sh = infer(f, map(gethint, xs)...) @@ -34,7 +34,7 @@ end function shapesv(f, args...) (g = graph(f)) == nothing && return ins = [vertex(Hint(d), inputnode(i)) for (i,d) in enumerate(args)] - interpret(Context(interpshape), g, ins...) + interpv(Context(interpshape), detuple(spliceinputs(g, ins...))) end shapes(args...) = shapesv(args...) |> syntax |> applylines |> (x->prettify(x, lines=true)) @@ -43,6 +43,8 @@ shapes(args...) = shapesv(args...) |> syntax |> applylines |> (x->prettify(x, li infer(f, args...) = graph(f) == nothing ? nothing : gethint(shapesv(f, args...)) +infer(::typeof(tuple), xs...) = (xs...,) +infer(s::Split, xs::Tuple) = 1 ≤ s.n ≤ length(xs) ? xs[s.n] : nothing infer(::typeof(identity), x) = x function infer(::typeof(*), a::Dims{2}, b::Dims{2}) @@ -50,8 +52,7 @@ function infer(::typeof(*), a::Dims{2}, b::Dims{2}) (a[1], b[2]) end -# TODO: make correct -infer(::typeof(+), a, b) = a +infer(::typeof(+), xs::Dims...) = Base.Broadcast.broadcast_shape(xs...) # Shapes macro