diff --git a/src/jit/JIT.jl b/src/jit/JIT.jl index 3283005a..2f30bb0c 100644 --- a/src/jit/JIT.jl +++ b/src/jit/JIT.jl @@ -1,5 +1,8 @@ module JIT +include("shapes.jl") +include("inplace.jl") include("trace.jl") +include("lib.jl") end diff --git a/src/jit/inplace.jl b/src/jit/inplace.jl new file mode 100644 index 00000000..cf87b102 --- /dev/null +++ b/src/jit/inplace.jl @@ -0,0 +1,11 @@ +mutable struct Cached{F,A} + f::F + buffer::A +end + +function (c::Cached)(args...) + sh = shape(c.f, shape(args)...) + bytes(sh) > length(c.buffer) && (c.buffer = similar(c.buffer, bytes(sh))) + y = restructure(sh, c.buffer) + inplace!(c.f, y, args...) +end diff --git a/src/jit/lib.jl b/src/jit/lib.jl new file mode 100644 index 00000000..cc89fa00 --- /dev/null +++ b/src/jit/lib.jl @@ -0,0 +1,9 @@ +# Primitive definitions + +inplace!(::typeof(*), C::AbstractArray, A::AbstractMatrix, B::AbstractArray) = + A_mul_B!(C, A, B) + +shape(::typeof(broadcast), f, xs...) = + Shape{eltype(xs[1])}(Base.Broadcast.broadcast_shape(size.(xs)...)...) + +inplace!(::typeof(broadcast), y, f, xs...) = broadcast!(f, y, xs...) diff --git a/src/jit/shapes.jl b/src/jit/shapes.jl new file mode 100644 index 00000000..bd6f4993 --- /dev/null +++ b/src/jit/shapes.jl @@ -0,0 +1,37 @@ +struct Shape{T,N} + dims::NTuple{N,Int} +end + +VecShape{T} = Shape{T,1} +MatShape{T} = Shape{T,2} + +Shape{T}(dims::Vararg{Integer,N}) where {T,N} = Shape{T,N}(dims) + +Base.size(s::Shape) = s.dims +Base.size(s::Shape, n) = s.dims[n] +Base.length(s::Shape) = prod(s.dims) +Base.eltype(s::Shape{T}) where T = T + +Base.sizeof(s::Shape{T}) where T = sizeof(T)*prod(size(s)) + +function Base.show(io::IO, s::Shape{T}) where T + print(io, "Shape{$T}(") + join(io, s.dims, ", ") + print(io, ")") +end + +shape(x) = typeof(x) +shape(x::Shape) = x +shape(x::Tuple) = shape.(x) +shape(x::AbstractArray) = Shape{eltype(x)}(size(x)...) + +bytes(s::Shape) = sizeof(s) +bytes(x::Tuple) = sum(bytes.(x)) + +# Recover structure from byte buffers +# Make sure to hold on to the parent buffer for the lifetime of the data. + +function restructure(sh::Shape{T}, buf::Vector{UInt8}) where T + buf = unsafe_wrap(Array, pointer(buf), sizeof(sh)) + reshape(reinterpret(T, buf), size(sh)) +end