diff --git a/src/Batches/iter.jl b/src/Batches/iter.jl index 70673739..27fd6c26 100644 --- a/src/Batches/iter.jl +++ b/src/Batches/iter.jl @@ -1,4 +1,10 @@ -import Base: start, next, done, iteratorsize, iteratoreltype, eltype, length +# Simple version + +using Base.Iterators: partition + +partitionr(xs, n) = take(partition(xs, n), length(xs)÷n) + +batches(xs, n) = (Batch([xs...]) for xs in partitionr(xs, n)) # Stateful iteration @@ -40,6 +46,8 @@ end # Batched +import Base: start, next, done, iteratorsize, iteratoreltype, eltype, length + struct Batched{I<:StatefulIter,S} itr::I buf::S