diff --git a/src/compiler/shape.jl b/src/compiler/shape.jl index 0b936b84..1e6ee0e0 100644 --- a/src/compiler/shape.jl +++ b/src/compiler/shape.jl @@ -1,5 +1,7 @@ using DataFlow.Interpreter +export @shapes + type Hint typ end @@ -46,3 +48,13 @@ end # TODO: make correct infer(::typeof(+), a, b) = a + +# Shapes macro + +_shape(xs::AbstractArray) = size(xs) +_shape(xs::Tuple) = map(_shape, xs) + +macro shapes(ex) + @capture(ex, f_(args__)) || error("@shapes f(args...)") + :(shapes($(esc(f)), _shape(($(map(esc, args)...),))...)) +end