basics
This commit is contained in:
parent
70fbbf48fa
commit
0e0057b0c4
@ -1,5 +1,8 @@
|
||||
module JIT
|
||||
|
||||
include("shapes.jl")
|
||||
include("inplace.jl")
|
||||
include("trace.jl")
|
||||
include("lib.jl")
|
||||
|
||||
end
|
||||
|
11
src/jit/inplace.jl
Normal file
11
src/jit/inplace.jl
Normal file
@ -0,0 +1,11 @@
|
||||
mutable struct Cached{F,A}
|
||||
f::F
|
||||
buffer::A
|
||||
end
|
||||
|
||||
function (c::Cached)(args...)
|
||||
sh = shape(c.f, shape(args)...)
|
||||
bytes(sh) > length(c.buffer) && (c.buffer = similar(c.buffer, bytes(sh)))
|
||||
y = restructure(sh, c.buffer)
|
||||
inplace!(c.f, y, args...)
|
||||
end
|
9
src/jit/lib.jl
Normal file
9
src/jit/lib.jl
Normal file
@ -0,0 +1,9 @@
|
||||
# Primitive definitions
|
||||
|
||||
inplace!(::typeof(*), C::AbstractArray, A::AbstractMatrix, B::AbstractArray) =
|
||||
A_mul_B!(C, A, B)
|
||||
|
||||
shape(::typeof(broadcast), f, xs...) =
|
||||
Shape{eltype(xs[1])}(Base.Broadcast.broadcast_shape(size.(xs)...)...)
|
||||
|
||||
inplace!(::typeof(broadcast), y, f, xs...) = broadcast!(f, y, xs...)
|
37
src/jit/shapes.jl
Normal file
37
src/jit/shapes.jl
Normal file
@ -0,0 +1,37 @@
|
||||
struct Shape{T,N}
|
||||
dims::NTuple{N,Int}
|
||||
end
|
||||
|
||||
VecShape{T} = Shape{T,1}
|
||||
MatShape{T} = Shape{T,2}
|
||||
|
||||
Shape{T}(dims::Vararg{Integer,N}) where {T,N} = Shape{T,N}(dims)
|
||||
|
||||
Base.size(s::Shape) = s.dims
|
||||
Base.size(s::Shape, n) = s.dims[n]
|
||||
Base.length(s::Shape) = prod(s.dims)
|
||||
Base.eltype(s::Shape{T}) where T = T
|
||||
|
||||
Base.sizeof(s::Shape{T}) where T = sizeof(T)*prod(size(s))
|
||||
|
||||
function Base.show(io::IO, s::Shape{T}) where T
|
||||
print(io, "Shape{$T}(")
|
||||
join(io, s.dims, ", ")
|
||||
print(io, ")")
|
||||
end
|
||||
|
||||
shape(x) = typeof(x)
|
||||
shape(x::Shape) = x
|
||||
shape(x::Tuple) = shape.(x)
|
||||
shape(x::AbstractArray) = Shape{eltype(x)}(size(x)...)
|
||||
|
||||
bytes(s::Shape) = sizeof(s)
|
||||
bytes(x::Tuple) = sum(bytes.(x))
|
||||
|
||||
# Recover structure from byte buffers
|
||||
# Make sure to hold on to the parent buffer for the lifetime of the data.
|
||||
|
||||
function restructure(sh::Shape{T}, buf::Vector{UInt8}) where T
|
||||
buf = unsafe_wrap(Array, pointer(buf), sizeof(sh))
|
||||
reshape(reinterpret(T, buf), size(sh))
|
||||
end
|
Loading…
Reference in New Issue
Block a user