shape: handle tuples better
This commit is contained in:
parent
6756ce7528
commit
6237aa6739
@ -22,7 +22,7 @@ ihint(f, args...) = f(args...)
|
||||
hintify(ctx, c::Constant) = hintify(ctx, state(c.value))
|
||||
hintify(ctx, xs::AbstractArray) = vertex(Hint(size(xs)), constant(:_))
|
||||
|
||||
interpshape = mux(ilinev, ihint, iargs, ituple, hintify)
|
||||
interpshape = mux(ilinev, ihint, iargs, hintify)
|
||||
|
||||
function hintify(ctx, f, xs...)
|
||||
sh = infer(f, map(gethint, xs)...)
|
||||
@ -34,7 +34,7 @@ end
|
||||
function shapesv(f, args...)
|
||||
(g = graph(f)) == nothing && return
|
||||
ins = [vertex(Hint(d), inputnode(i)) for (i,d) in enumerate(args)]
|
||||
interpret(Context(interpshape), g, ins...)
|
||||
interpv(Context(interpshape), detuple(spliceinputs(g, ins...)))
|
||||
end
|
||||
|
||||
shapes(args...) = shapesv(args...) |> syntax |> applylines |> (x->prettify(x, lines=true))
|
||||
@ -43,6 +43,8 @@ shapes(args...) = shapesv(args...) |> syntax |> applylines |> (x->prettify(x, li
|
||||
|
||||
infer(f, args...) = graph(f) == nothing ? nothing : gethint(shapesv(f, args...))
|
||||
|
||||
infer(::typeof(tuple), xs...) = (xs...,)
|
||||
infer(s::Split, xs::Tuple) = 1 ≤ s.n ≤ length(xs) ? xs[s.n] : nothing
|
||||
infer(::typeof(identity), x) = x
|
||||
|
||||
function infer(::typeof(*), a::Dims{2}, b::Dims{2})
|
||||
@ -50,8 +52,7 @@ function infer(::typeof(*), a::Dims{2}, b::Dims{2})
|
||||
(a[1], b[2])
|
||||
end
|
||||
|
||||
# TODO: make correct
|
||||
infer(::typeof(+), a, b) = a
|
||||
infer(::typeof(+), xs::Dims...) = Base.Broadcast.broadcast_shape(xs...)
|
||||
|
||||
# Shapes macro
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user