From 0ed4e65d2f827960269f5009fae6d6a082321cd6 Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Thu, 11 May 2017 17:32:14 +0100 Subject: [PATCH] refactor --- src/dims/catmat.jl | 8 +++- src/dims/iter.jl | 114 ++++++++++++++++++++++----------------------- 2 files changed, 64 insertions(+), 58 deletions(-) diff --git a/src/dims/catmat.jl b/src/dims/catmat.jl index abc241d8..b4c170e2 100644 --- a/src/dims/catmat.jl +++ b/src/dims/catmat.jl @@ -15,7 +15,13 @@ size(b::CatMat) = (size(b.data, 1),) getindex(b::CatMat, i)::eltype(b) = slicedim(b.data, 1, i) -setindex!(b::CatMat, v, i) = b.data[i, :] = v +setindex!(b::CatMat, v, i::Integer) = b.data[i, :] = v + +function setindex!(b::CatMat, xs, ::Colon) + for (i, x) in enumerate(xs) + b[i] = x + end +end allequal(xs) = all(x -> x == first(xs), xs) diff --git a/src/dims/iter.jl b/src/dims/iter.jl index 07b7672d..d35f14d3 100644 --- a/src/dims/iter.jl +++ b/src/dims/iter.jl @@ -1,72 +1,72 @@ export Batched -zipt(xs...) = (xs,) -zipt(xs::Tuple...) = zip(xs...) - import Base: start, next, done, iteratorsize, iteratoreltype, eltype, length -mutable struct Batched{T,S} - batch::Int - iter::T - "`Batched` always read a batch in advance, and store it in `buf`" +# Stateful iteration + +mutable struct StatefulIter{I,S,T} + iter::I + state::S + next::Nullable{T} +end + +function StatefulIter(itr) + state = start(itr) + val, state = done(itr, state) ? 
(Nullable(), state) : next(itr, state) + return StatefulIter(itr, state, convert(Nullable, val)) +end + +peek(s::StatefulIter) = get(s.next) + +function Base.take!(s::StatefulIter) + x = peek(s) + if !done(s.iter, s.state) + s.next, s.state = next(s.iter, s.state) + else + s.next = Nullable() + end + return x +end + +Base.isempty(s::StatefulIter) = isnull(s.next) +Base.eltype(s::StatefulIter) = eltype(s.next) + +function taken!(s::StatefulIter, n::Integer) + xs = eltype(s)[] + for _ = 1:n + isempty(s) && break + push!(xs, take!(s)) + end + return xs +end + +# Batched + +struct Batched{I<:StatefulIter,S} + itr::I buf::S - i end -function Batched(iter::T, batch::Integer) where T - batch >= 1 || throw(ArgumentError("batch size must >= 1")) - i = start(iter) - done(iter, i) && return Batched{T,Void}(batch, iter, nothing, i) - v, i = next(iter, i) - - buf = mapt(v) do x - storage = Array{eltype(x)}(batch, size(x)...) - storage[1, :] = x - rebatch(storage) - end - - for ibatch in 2:batch - if done(iter, i) - warn("data less than one batch will be ignored, please use a smaller batch size") - return Batched{T,Void}(batch, iter, nothing, i) - end - - v, i = next(iter, i) - map(x->setindex!(x..., ibatch), zipt(buf, v)) - end - - Batched{T,typeof(buf)}(batch, iter, buf, i) +function Batched(itr, n::Integer) + n >= 1 || throw(ArgumentError("batch size must be >= 1")) + itr = StatefulIter(itr) + buf = rebatch(similar(eltype(itr)(), n, size(peek(itr))...)) + Batched(itr, buf) end -iteratoreltype(::Type{Batched{T,S}}) where {T,S} = Base.HasEltype() +iteratoreltype(::Type{<:Batched}) = Base.HasEltype() +iteratorsize(::Type{<:Batched}) = Base.SizeUnknown() -iteratorsize(::Type{Batched{T,S}}) where {T,S} = - iteratorsize(T) isa Base.HasShape ? 
- Base.HasLength() : iteratorsize(T) +eltype{T,S}(x::Batched{T,S}) = S -length(x::Batched) = length(x.iter) ÷ x.batch +start(::Batched) = () -eltype(x::Batched{T,S}) where {T,S} = S - -start(x::Batched) = true - -next(x::Batched, ::Bool) = x.buf, false +next(x::Batched, _) = x.buf, () # will be less hacky if https://github.com/JuliaLang/julia/issues/18823 -function done(x::Batched, fresh) - fresh && return false - - for ibatch in 1:x.batch - if done(x.iter, x.i) - ibatch != 1 && warn("cannot perfectly divide data by batch size, remainder will be discarded") - return true - end - - v, x.i = next(x.iter, x.i) - map(x->setindex!(x..., ibatch), zipt(x.buf, v)) - end - - false +function done(x::Batched, _) + next = taken!(x.itr, length(x.buf)) + length(next) < length(x.buf) && return true + x.buf[:] = next + return false end - -done(::Batched{T,Void}, ::Bool) where T = true