diff --git a/src/jit/shapes.jl b/src/jit/shapes.jl index 1443c4f0..45998b0a 100644 --- a/src/jit/shapes.jl +++ b/src/jit/shapes.jl @@ -8,6 +8,7 @@ VecShape{T} = Shape{T,1} MatShape{T} = Shape{T,2} Shape{T}(dims::Vararg{Integer,N}) where {T,N} = Shape{T,N}(dims) +Shape{T}(dims::NTuple{N,Integer}) where {T,N} = Shape{T,N}(dims) Base.size(s::Shape) = s.dims Base.size(s::Shape, n) = s.dims[n] @@ -22,7 +23,7 @@ function Base.show(io::IO, s::Shape{T}) where T print(io, ")") end -shape(x) = typeof(x) +shape(x) = x shape(x::Shape) = x shape(x::Tuple) = shape.(x) shape(x::AbstractArray) = Shape{eltype(x)}(size(x)...) diff --git a/src/jit/trace.jl b/src/jit/trace.jl index 565d7724..01bf0afc 100644 --- a/src/jit/trace.jl +++ b/src/jit/trace.jl @@ -53,9 +53,12 @@ struct Compiled{F,T<:Tuple} params::T end -(c::Compiled)(args...) = - Tracker.track(Tracker.Call(c, args...), - c.func(Tracker.data.(c.params), args...)) +# TODO when we support derivatives +# (c::Compiled)(args...) = +# Tracker.track(Tracker.Call(c, args...), +# c.func(Tracker.data.(c.params), args...)) + +(c::Compiled)(args...) = c.func(Tracker.data.(c.params), Tracker.data.(args)...) Base.show(io::IO, c::Compiled) = print(io, "Compiled(", c.model, ")")