shape: handle tuples better
This commit is contained in:
parent
6756ce7528
commit
6237aa6739
@ -22,7 +22,7 @@ ihint(f, args...) = f(args...)
|
|||||||
hintify(ctx, c::Constant) = hintify(ctx, state(c.value))
|
hintify(ctx, c::Constant) = hintify(ctx, state(c.value))
|
||||||
hintify(ctx, xs::AbstractArray) = vertex(Hint(size(xs)), constant(:_))
|
hintify(ctx, xs::AbstractArray) = vertex(Hint(size(xs)), constant(:_))
|
||||||
|
|
||||||
interpshape = mux(ilinev, ihint, iargs, ituple, hintify)
|
interpshape = mux(ilinev, ihint, iargs, hintify)
|
||||||
|
|
||||||
function hintify(ctx, f, xs...)
|
function hintify(ctx, f, xs...)
|
||||||
sh = infer(f, map(gethint, xs)...)
|
sh = infer(f, map(gethint, xs)...)
|
||||||
@ -34,7 +34,7 @@ end
|
|||||||
function shapesv(f, args...)
|
function shapesv(f, args...)
|
||||||
(g = graph(f)) == nothing && return
|
(g = graph(f)) == nothing && return
|
||||||
ins = [vertex(Hint(d), inputnode(i)) for (i,d) in enumerate(args)]
|
ins = [vertex(Hint(d), inputnode(i)) for (i,d) in enumerate(args)]
|
||||||
interpret(Context(interpshape), g, ins...)
|
interpv(Context(interpshape), detuple(spliceinputs(g, ins...)))
|
||||||
end
|
end
|
||||||
|
|
||||||
shapes(args...) = shapesv(args...) |> syntax |> applylines |> (x->prettify(x, lines=true))
|
shapes(args...) = shapesv(args...) |> syntax |> applylines |> (x->prettify(x, lines=true))
|
||||||
@ -43,6 +43,8 @@ shapes(args...) = shapesv(args...) |> syntax |> applylines |> (x->prettify(x, li
|
|||||||
|
|
||||||
infer(f, args...) = graph(f) == nothing ? nothing : gethint(shapesv(f, args...))
|
infer(f, args...) = graph(f) == nothing ? nothing : gethint(shapesv(f, args...))
|
||||||
|
|
||||||
|
infer(::typeof(tuple), xs...) = (xs...,)
|
||||||
|
infer(s::Split, xs::Tuple) = 1 ≤ s.n ≤ length(xs) ? xs[s.n] : nothing
|
||||||
infer(::typeof(identity), x) = x
|
infer(::typeof(identity), x) = x
|
||||||
|
|
||||||
function infer(::typeof(*), a::Dims{2}, b::Dims{2})
|
function infer(::typeof(*), a::Dims{2}, b::Dims{2})
|
||||||
@ -50,8 +52,7 @@ function infer(::typeof(*), a::Dims{2}, b::Dims{2})
|
|||||||
(a[1], b[2])
|
(a[1], b[2])
|
||||||
end
|
end
|
||||||
|
|
||||||
# TODO: make correct
|
infer(::typeof(+), xs::Dims...) = Base.Broadcast.broadcast_shape(xs...)
|
||||||
infer(::typeof(+), a, b) = a
|
|
||||||
|
|
||||||
# Shapes macro
|
# Shapes macro
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user