shape: handle tuples better

This commit is contained in:
Mike J Innes 2017-03-20 23:10:38 +00:00
parent 6756ce7528
commit 6237aa6739

View File

@ -22,7 +22,7 @@ ihint(f, args...) = f(args...)
hintify(ctx, c::Constant) = hintify(ctx, state(c.value))
hintify(ctx, xs::AbstractArray) = vertex(Hint(size(xs)), constant(:_))
interpshape = mux(ilinev, ihint, iargs, ituple, hintify)
interpshape = mux(ilinev, ihint, iargs, hintify)
function hintify(ctx, f, xs...)
sh = infer(f, map(gethint, xs)...)
@ -34,7 +34,7 @@ end
function shapesv(f, args...)
(g = graph(f)) == nothing && return
ins = [vertex(Hint(d), inputnode(i)) for (i,d) in enumerate(args)]
interpret(Context(interpshape), g, ins...)
interpv(Context(interpshape), detuple(spliceinputs(g, ins...)))
end
shapes(args...) = shapesv(args...) |> syntax |> applylines |> (x->prettify(x, lines=true))
@ -43,6 +43,8 @@ shapes(args...) = shapesv(args...) |> syntax |> applylines |> (x->prettify(x, li
infer(f, args...) = graph(f) == nothing ? nothing : gethint(shapesv(f, args...))
infer(::typeof(tuple), xs...) = (xs...,)
infer(s::Split, xs::Tuple) = 1 s.n length(xs) ? xs[s.n] : nothing
infer(::typeof(identity), x) = x
function infer(::typeof(*), a::Dims{2}, b::Dims{2})
@ -50,8 +52,7 @@ function infer(::typeof(*), a::Dims{2}, b::Dims{2})
(a[1], b[2])
end
# TODO: make correct
infer(::typeof(+), a, b) = a
infer(::typeof(+), xs::Dims...) = Base.Broadcast.broadcast_shape(xs...)
# Shapes macro