From 1bd0a43b7d4d4532028b41958f47412384657f2e Mon Sep 17 00:00:00 2001 From: ylxdzsw Date: Thu, 11 May 2017 15:47:19 +0100 Subject: [PATCH] batch iterator --- src/Flux.jl | 1 + src/dims/iter.jl | 72 ++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 73 insertions(+) create mode 100644 src/dims/iter.jl diff --git a/src/Flux.jl b/src/Flux.jl index d8d20f53..c72b3171 100644 --- a/src/Flux.jl +++ b/src/Flux.jl @@ -19,6 +19,7 @@ include("model.jl") include("dims/catmat.jl") include("dims/batching.jl") include("dims/seq.jl") +include("dims/iter.jl") include("compiler/code.jl") include("compiler/loops.jl") diff --git a/src/dims/iter.jl b/src/dims/iter.jl new file mode 100644 index 00000000..07b7672d --- /dev/null +++ b/src/dims/iter.jl @@ -0,0 +1,72 @@ +export Batched + +zipt(xs...) = (xs,) +zipt(xs::Tuple...) = zip(xs...) + +import Base: start, next, done, iteratorsize, iteratoreltype, eltype, length + +mutable struct Batched{T,S} + batch::Int + iter::T + "`Batched` always read a batch in advance, and store it in `buf`" + buf::S + i +end + +function Batched(iter::T, batch::Integer) where T + batch >= 1 || throw(ArgumentError("batch size must >= 1")) + i = start(iter) + done(iter, i) && return Batched{T,Void}(batch, iter, nothing, i) + v, i = next(iter, i) + + buf = mapt(v) do x + storage = Array{eltype(x)}(batch, size(x)...) + storage[1, :] = x + rebatch(storage) + end + + for ibatch in 2:batch + if done(iter, i) + warn("data less than one batch will be ignored, please use a smaller batch size") + return Batched{T,Void}(batch, iter, nothing, i) + end + + v, i = next(iter, i) + map(x->setindex!(x..., ibatch), zipt(buf, v)) + end + + Batched{T,typeof(buf)}(batch, iter, buf, i) +end + +iteratoreltype(::Type{Batched{T,S}}) where {T,S} = Base.HasEltype() + +iteratorsize(::Type{Batched{T,S}}) where {T,S} = + iteratorsize(T) isa Base.HasShape ? + Base.HasLength() : iteratorsize(T) + +length(x::Batched) = length(x.iter) รท x.batch + +eltype(x::Batched{T,S}) where {T,S} = S + +start(x::Batched) = true + +next(x::Batched, ::Bool) = x.buf, false + +# will be less hacky if https://github.com/JuliaLang/julia/issues/18823 +function done(x::Batched, fresh) + fresh && return false + + for ibatch in 1:x.batch + if done(x.iter, x.i) + ibatch != 1 && warn("cannot perfectly divide data by batch size, remainder will be discarded") + return true + end + + v, x.i = next(x.iter, x.i) + map(x->setindex!(x..., ibatch), zipt(x.buf, v)) + end + + false +end + +done(::Batched{T,Void}, ::Bool) where T = true