Flux.jl/src/jit/shapes.jl
2018-03-16 14:42:08 +00:00

57 lines
1.4 KiB
Julia

using ..Tracker: TrackedArray
# `Shape{T,N}` describes an `N`-dimensional array with element type `T` by its
# dimensions alone, without holding any data.
struct Shape{T,N}
    dims::NTuple{N,Int}
end

# Aliases for the common one- and two-dimensional cases.
VecShape{T} = Shape{T,1}
MatShape{T} = Shape{T,2}

# Build a shape from dimensions passed either as separate integers,
# `Shape{T}(2, 3)`, or as one tuple, `Shape{T}((2, 3))`. `N` is taken from the
# number of dimensions, and each dimension is converted to `Int` by the field type.
Shape{T}(dims::Vararg{Integer,N}) where {T,N} = Shape{T,N}(dims)
Shape{T}(dims::NTuple{N,Integer}) where {T,N} = Shape{T,N}(dims)

# Array-style queries, so a `Shape` answers the same questions as the array it
# stands for.
Base.size(s::Shape) = s.dims
# General index (e.g. a range of dimensions): forwarded to the `dims` tuple.
Base.size(s::Shape, n) = s.dims[n]
# Single dimension: as with `size(::AbstractArray, d)`, dimensions beyond
# `ndims(s)` have extent 1; `n < 1` still throws a `BoundsError`.
Base.size(s::Shape{T,N}, n::Integer) where {T,N} = n > N ? 1 : s.dims[n]
Base.ndims(s::Shape{T,N}) where {T,N} = N
# Number of elements; 1 for a zero-dimensional shape (empty product).
Base.length(s::Shape) = prod(s.dims)
Base.eltype(s::Shape{T}) where T = T
# Bytes occupied by `length(s)` values of type `T`.
Base.sizeof(s::Shape{T}) where T = sizeof(T) * prod(size(s))

# Print as the constructor call that rebuilds the shape, e.g. `Shape{Float32}(2, 3)`.
function Base.show(io::IO, s::Shape{T}) where T
    print(io, "Shape{$T}(")
    join(io, s.dims, ", ")
    print(io, ")")
end
# `shape(x)` summarises a value for shape inference.
# Any value without a more specific method serves as its own shape.
shape(x) = x
# A `Shape` is already a shape.
shape(x::Shape) = x
# Tuples keep their structure: each element is summarised independently.
shape(x::Tuple) = map(shape, x)
# Arrays reduce to their element type and dimensions.
shape(x::AbstractArray) = Shape{eltype(x)}(size(x))
# Tracked arrays are summarised by their `data` field (presumably the wrapped,
# untracked array — confirm against Tracker).
shape(x::TrackedArray) = shape(x.data)
# Storage needed for data of shape `s`, in bytes (the same value as `sizeof(s)`).
function bytes(s::Shape)
    return sizeof(s)
end
# Total byte count for a (possibly nested) tuple of shapes: the sum over its elements.
# An empty tuple needs no storage, so it yields 0 rather than failing in `sum`,
# which cannot reduce over an empty collection.
bytes(x::Tuple) = isempty(x) ? 0 : sum(bytes, x)
# Recover structure from byte buffers.
# `restructure(sh, buf)` views the first `sizeof(sh)` bytes of `buf` as an array
# with element type `T` and dimensions `size(sh)`; no data is copied.
# The result aliases `buf` through a raw pointer and does NOT keep `buf` alive:
# make sure to hold on to the parent buffer for the lifetime of the data.
# Throws `ArgumentError` if `buf` holds fewer than `sizeof(sh)` bytes.
function restructure(sh::Shape{T}, buf::Vector{UInt8}) where T
    nbytes = sizeof(sh)
    # `unsafe_wrap` does no bounds checking, so refuse buffers that are too short
    # instead of reading or writing past the end of `buf`.
    length(buf) >= nbytes || throw(ArgumentError(
        "buffer of $(length(buf)) bytes is too small for $sh ($nbytes bytes)"))
    # Wrap only the needed prefix so the reinterpreted length matches `sh` exactly.
    raw = unsafe_wrap(Array, pointer(buf), nbytes)
    return reshape(reinterpret(T, raw), size(sh))
end
# Execution with caches
# `Cached` pairs a function `f` with a reusable `buffer`. The struct is `mutable`
# so that the `buffer` field can be replaced by a larger one when needed.
mutable struct Cached{F,A}
    f::F        # the wrapped function
    buffer::A   # reusable storage; presumably a `Vector{UInt8}` (it is passed to `restructure`)
end
# Call the wrapped function, writing its output into the cached buffer.
# The output shape is computed by `shape(c.f, ...)` from the shapes of `args`. If
# `c.buffer` holds fewer bytes than that shape needs, it is replaced by a larger
# buffer obtained from `similar`. The value returned by `inplace!` is returned.
# NOTE(review): the output array aliases `c.buffer`. If `inplace!` returns it, the
# result is presumably overwritten by the next call. Also, once `c.buffer` is replaced,
# nothing keeps the old memory alive. Confirm that callers copy any results they keep.
function (c::Cached)(args...)
    outshape = shape(c.f, shape(args)...)
    needed = bytes(outshape)
    if needed > length(c.buffer)
        c.buffer = similar(c.buffer, needed)
    end
    out = restructure(outshape, c.buffer)
    return inplace!(c.f, out, args...)
end