Flux.jl/src/dims/catmat.jl

51 lines
1.1 KiB
Julia
Raw Normal View History

2016-10-25 12:48:30 +00:00
import Base: eltype, size, getindex, setindex!, convert
2016-10-25 13:10:32 +00:00
export CatMat, rawbatch
2016-10-25 12:48:30 +00:00
# Vector-like wrapper around a single backing array `data` of type `S`.
# It is declared as an `AbstractVector{T}`; the methods in this file report its
# length as `size(data, 1)` and read element `i` as the slice of `data` taken at
# index `i` along the first dimension.
immutable CatMat{T,S} <: AbstractVector{T}
data::S  # backing storage holding every element along its first dimension
end
2016-10-25 13:10:27 +00:00
# Convert a raw storage value of type `S` into a `CatMat{T,S}` by wrapping it
# directly; no copy of `storage` is made by this method itself.
function convert{T,S}(::Type{CatMat{T,S}}, storage::S)
    return CatMat{T,S}(storage)
end
2016-10-25 12:48:30 +00:00
# The element type of a `CatMat` is its first type parameter `T`.
function eltype{T}(::CatMat{T})
    return T
end
# A `CatMat` is one-dimensional: its length is the extent of the backing
# array along its first dimension.
function size(b::CatMat)
    count = size(b.data, 1)
    return (count,)
end
# Read element `i` as the slice of the backing array at index `i` along the
# first dimension. The declared return type converts the slice to `eltype(b)`.
function getindex(b::CatMat, i)::eltype(b)
    return slicedim(b.data, 1, i)
end
# Store `v` as element `i` by writing it into row `i` of the backing array.
#
# The write must target `b.data`: assigning `b[i, :] = v` on the wrapper itself
# dispatches to a two-index `setindex!` on `CatMat` (which only defines the
# one-index form) and so never reaches the storage. This mirrors how the
# constructor fills storage with `storage[i, :] = xs[i]`.
#
# The value of the call is `v`, as with any assignment expression.
function setindex!(b::CatMat, v, i)
    b.data[i, :] = v
end
# Return `true` when every element of `xs` compares `==` to the first element.
# An empty collection yields `true` (the loop body never runs, so `first` is
# never called on it).
function allequal(xs)
    for x in xs
        x == first(xs) || return false
    end
    return true
end
# Build a `CatMat{T,S}` by copying each element of `xs` into the pre-allocated
# `storage`, one element per index along the first dimension.
#
# Preconditions (checked with `@assert`):
#   * every element of `xs` has the same `size`;
#   * `storage` has size `(length(xs), size(first(xs))...)`.
function (::Type{CatMat{T,S}}){T,S}(xs, storage::S)
    @assert @>> xs map(size) allequal
    @assert size(storage) == (length(xs), size(first(xs))...)
    # Element number `row` of `xs` becomes slice `row` of the storage.
    for (row, x) in enumerate(xs)
        storage[row, :] = x
    end
    return CatMat{T,S}(storage)
end
# Build a `CatMat{T}` from a collection of elements, allocating the storage.
#
# Each element is first unwrapped with `rawbatch`. The storage is then created
# with `similar` from the first element, sized `(length(xs), size(first)...)`,
# and filled by the two-argument constructor.
function (::Type{CatMat{T}}){T}(xs)
    raw = map(rawbatch, xs)
    template = first(raw)
    dims = (length(raw), size(template)...)
    storage = similar(template, dims)
    return CatMat{T,typeof(storage)}(raw, storage)
end
# Convenience constructor: promote all elements of `xs` to a common type and
# use the element type of the promoted collection as `T`.
function CatMat(xs)
    promoted = promote(xs...)
    return CatMat{eltype(promoted)}(promoted)
end
# Display definition for `CatMat` values in the `Juno.Inline` context.
# The body builds a `Tree` from two parts:
#   * a header `Row` with the text "CatMat of ", the element type of `b`, and
#     a faded "[<length of b>]" suffix;
#   * the elements of `b`, materialised with `collect` and passed through
#     `Juno.trim`.
# NOTE(review): presumably `@render` registers this as the inline renderer used
# by the Juno IDE, and `Juno.trim` shortens long collections; confirm against
# the Juno documentation.
@render Juno.Inline b::CatMat begin
Tree(Row(Text("CatMat of "), eltype(b),
Juno.fade("[$(length(b))]")),
Juno.trim(collect(b)))
end
# Generic fallback: a value that is not a wrapper is already "raw", so it is
# returned unchanged (the same object, not a copy).
function rawbatch(xs)
    return xs
end
# For a `CatMat`, the raw form is its backing array.
function rawbatch(xs::CatMat)
    return xs.data
end