refactor
This commit is contained in:
parent
1bd0a43b7d
commit
0ed4e65d2f
@ -15,7 +15,13 @@ size(b::CatMat) = (size(b.data, 1),)
|
|||||||
|
|
||||||
getindex(b::CatMat, i)::eltype(b) = slicedim(b.data, 1, i)
|
getindex(b::CatMat, i)::eltype(b) = slicedim(b.data, 1, i)
|
||||||
|
|
||||||
setindex!(b::CatMat, v, i) = b.data[i, :] = v
|
# Write `v` into row `i` of the backing array, i.e. `b.data[i, :]`.
# Like the assignment expression it wraps, the call evaluates to `v`.
function setindex!(b::CatMat, v, i::Integer)
    return b.data[i, :] = v
end
|
||||||
|
|
||||||
|
# Bulk assignment `b[:] = xs`: the k-th element produced by `xs` is stored at
# `b[k]` through the scalar `setindex!`. No length check is done here; any
# bounds error comes from the underlying `b[k] = x` call.
function setindex!(b::CatMat, xs, ::Colon)
    slot = 0
    for item in xs
        slot += 1
        b[slot] = item
    end
    return nothing
end
|
||||||
|
|
||||||
allequal(xs) = all(x -> x == first(xs), xs)
|
allequal(xs) = all(x -> x == first(xs), xs)
|
||||||
|
|
||||||
|
114
src/dims/iter.jl
114
src/dims/iter.jl
@ -1,72 +1,72 @@
|
|||||||
export Batched
|
export Batched
|
||||||
|
|
||||||
zipt(xs...) = (xs,)
|
|
||||||
zipt(xs::Tuple...) = zip(xs...)
|
|
||||||
|
|
||||||
import Base: start, next, done, iteratorsize, iteratoreltype, eltype, length
|
import Base: start, next, done, iteratorsize, iteratoreltype, eltype, length
|
||||||
|
|
||||||
mutable struct Batched{T,S}
|
# Stateful iteration
|
||||||
batch::Int
|
|
||||||
iter::T
|
mutable struct StatefulIter{I,S,T}
|
||||||
"`Batched` always read a batch in advance, and store it in `buf`"
|
iter::I
|
||||||
|
state::S
|
||||||
|
next::Nullable{T}
|
||||||
|
end
|
||||||
|
|
||||||
|
# Wrap `itr` in a `StatefulIter`, eagerly fetching its first element.
#
# The first element (if any) is stored in the `next` field as a `Nullable`;
# when `itr` is already exhausted the field holds an empty `Nullable()` and
# the state returned by `start` is kept unchanged.
function StatefulIter(itr)
    st = start(itr)
    if done(itr, st)
        head = Nullable()
    else
        head, st = next(itr, st)
    end
    return StatefulIter(itr, st, convert(Nullable, head))
end
|
||||||
|
|
||||||
|
# Return the element currently buffered in `s.next` without advancing `s`.
# The value is unwrapped with `get`, so an empty buffer raises the error
# `get` throws for a null `Nullable`.
function peek(s::StatefulIter)
    return get(s.next)
end
|
||||||
|
|
||||||
|
# Pop the buffered element of `s` and refill the buffer.
#
# The buffered element is read first (via `peek`, which errors when the buffer
# is empty). The buffer is then refilled from the wrapped iterator, or set to
# an empty `Nullable()` once the wrapped iterator is done.
function Base.take!(s::StatefulIter)
    current = peek(s)
    if done(s.iter, s.state)
        s.next = Nullable()
    else
        upcoming, newstate = next(s.iter, s.state)
        # Same order as a destructuring assignment: buffer first, then state.
        s.next = upcoming
        s.state = newstate
    end
    return current
end
|
||||||
|
|
||||||
|
# A `StatefulIter` is empty exactly when its look-ahead buffer `s.next` is null.
function Base.isempty(s::StatefulIter)
    return isnull(s.next)
end
|
||||||
|
# Element type of a `StatefulIter`: the element type of its `Nullable`
# look-ahead buffer `s.next`.
function Base.eltype(s::StatefulIter)
    return eltype(s.next)
end
|
||||||
|
|
||||||
|
# Take up to `n` elements from `s` and return them in a `Vector{eltype(s)}`.
#
# Stops early when `s` runs empty, so the result may hold fewer than `n`
# elements; it is empty when `n < 1` or `s` is already empty.
function taken!(s::StatefulIter, n::Integer)
    taken = eltype(s)[]
    remaining = n
    while remaining >= 1 && !isempty(s)
        push!(taken, take!(s))
        remaining -= 1
    end
    return taken
end
|
||||||
|
|
||||||
|
# Batched
|
||||||
|
|
||||||
|
struct Batched{I<:StatefulIter,S}
|
||||||
|
itr::I
|
||||||
buf::S
|
buf::S
|
||||||
i
|
|
||||||
end
|
end
|
||||||
|
|
||||||
function Batched(iter::T, batch::Integer) where T
|
function Batched(itr, n::Integer)
|
||||||
batch >= 1 || throw(ArgumentError("batch size must >= 1"))
|
n >= 1 || throw(ArgumentError("batch size must be >= 1"))
|
||||||
i = start(iter)
|
itr = StatefulIter(itr)
|
||||||
done(iter, i) && return Batched{T,Void}(batch, iter, nothing, i)
|
buf = rebatch(similar(eltype(itr)(), n, size(peek(itr))...))
|
||||||
v, i = next(iter, i)
|
Batched(itr, buf)
|
||||||
|
|
||||||
buf = mapt(v) do x
|
|
||||||
storage = Array{eltype(x)}(batch, size(x)...)
|
|
||||||
storage[1, :] = x
|
|
||||||
rebatch(storage)
|
|
||||||
end
|
|
||||||
|
|
||||||
for ibatch in 2:batch
|
|
||||||
if done(iter, i)
|
|
||||||
warn("data less than one batch will be ignored, please use a smaller batch size")
|
|
||||||
return Batched{T,Void}(batch, iter, nothing, i)
|
|
||||||
end
|
|
||||||
|
|
||||||
v, i = next(iter, i)
|
|
||||||
map(x->setindex!(x..., ibatch), zipt(buf, v))
|
|
||||||
end
|
|
||||||
|
|
||||||
Batched{T,typeof(buf)}(batch, iter, buf, i)
|
|
||||||
end
|
end
|
||||||
|
|
||||||
iteratoreltype(::Type{Batched{T,S}}) where {T,S} = Base.HasEltype()
|
# Iterator trait: every `Batched` type reports a known element type.
function iteratoreltype(::Type{<:Batched})
    return Base.HasEltype()
end
|
||||||
|
# Iterator trait: the number of batches a `Batched` yields is not known up front.
function iteratorsize(::Type{<:Batched})
    return Base.SizeUnknown()
end
|
||||||
|
|
||||||
iteratorsize(::Type{Batched{T,S}}) where {T,S} =
|
# Element type of a `Batched` iterator: its second type parameter `S`, which is
# the type of the `buf` field that `next` hands out.
# Written with `where` because the `f{T,S}(...)` method-parameter form is
# deprecated, and the rest of this file already uses `where`.
eltype(x::Batched{T,S}) where {T,S} = S
|
||||||
iteratorsize(T) isa Base.HasShape ?
|
|
||||||
Base.HasLength() : iteratorsize(T)
|
|
||||||
|
|
||||||
length(x::Batched) = length(x.iter) ÷ x.batch
|
# Iteration over a `Batched` carries no external state: the state is always
# the empty tuple.
function start(::Batched)
    return ()
end
|
||||||
|
|
||||||
eltype(x::Batched{T,S}) where {T,S} = S
|
# Yield the batch buffer `x.buf` together with the (empty-tuple) state.
# The incoming state argument is ignored.
function next(x::Batched, _)
    return x.buf, ()
end
|
||||||
|
|
||||||
start(x::Batched) = true
|
|
||||||
|
|
||||||
next(x::Batched, ::Bool) = x.buf, false
|
|
||||||
|
|
||||||
# will be less hacky if https://github.com/JuliaLang/julia/issues/18823
|
# will be less hacky if https://github.com/JuliaLang/julia/issues/18823
|
||||||
function done(x::Batched, fresh)
|
function done(x::Batched, _)
|
||||||
fresh && return false
|
next = taken!(x.itr, length(x.buf))
|
||||||
|
length(next) < length(x.buf) && return true
|
||||||
for ibatch in 1:x.batch
|
x.buf[:] = next
|
||||||
if done(x.iter, x.i)
|
return false
|
||||||
ibatch != 1 && warn("cannot perfectly divide data by batch size, remainder will be discarded")
|
|
||||||
return true
|
|
||||||
end
|
|
||||||
|
|
||||||
v, x.i = next(x.iter, x.i)
|
|
||||||
map(x->setindex!(x..., ibatch), zipt(x.buf, v))
|
|
||||||
end
|
|
||||||
|
|
||||||
false
|
|
||||||
end
|
end
|
||||||
|
|
||||||
done(::Batched{T,Void}, ::Bool) where T = true
|
|
||||||
|
Loading…
Reference in New Issue
Block a user