diff --git a/src/Flux.jl b/src/Flux.jl index d8d20f53..c72b3171 100644 --- a/src/Flux.jl +++ b/src/Flux.jl @@ -19,6 +19,7 @@ include("model.jl") include("dims/catmat.jl") include("dims/batching.jl") include("dims/seq.jl") +include("dims/iter.jl") include("compiler/code.jl") include("compiler/loops.jl") diff --git a/src/dims/catmat.jl b/src/dims/catmat.jl index abc241d8..b4c170e2 100644 --- a/src/dims/catmat.jl +++ b/src/dims/catmat.jl @@ -15,7 +15,13 @@ size(b::CatMat) = (size(b.data, 1),) getindex(b::CatMat, i)::eltype(b) = slicedim(b.data, 1, i) -setindex!(b::CatMat, v, i) = b.data[i, :] = v +setindex!(b::CatMat, v, i::Integer) = b.data[i, :] = v + +function setindex!(b::CatMat, xs, ::Colon) + for (i, x) in enumerate(xs) + b[i] = x + end +end allequal(xs) = all(x -> x == first(xs), xs) diff --git a/src/dims/iter.jl b/src/dims/iter.jl new file mode 100644 index 00000000..d35f14d3 --- /dev/null +++ b/src/dims/iter.jl @@ -0,0 +1,72 @@ +export Batched + +import Base: start, next, done, iteratorsize, iteratoreltype, eltype, length + +# Stateful iteration + +mutable struct StatefulIter{I,S,T} + iter::I + state::S + next::Nullable{T} +end + +function StatefulIter(itr) + state = start(itr) + val, state = done(itr, state) ? (Nullable(), state) : next(itr, state) + return StatefulIter(itr, state, convert(Nullable, val)) +end + +peek(s::StatefulIter) = get(s.next) + +function Base.take!(s::StatefulIter) + x = peek(s) + if !done(s.iter, s.state) + s.next, s.state = next(s.iter, s.state) + else + s.next = Nullable() + end + return x +end + +Base.isempty(s::StatefulIter) = isnull(s.next) +Base.eltype(s::StatefulIter) = eltype(s.next) + +function taken!(s::StatefulIter, n::Integer) + xs = eltype(s)[] + for _ = 1:n + isempty(s) && break + push!(xs, take!(s)) + end + return xs +end + +# Batched + +struct Batched{I<:StatefulIter,S} + itr::I + buf::S +end + +function Batched(itr, n::Integer) + n >= 1 || throw(ArgumentError("batch size must be >= 1")) + itr = StatefulIter(itr) + buf = rebatch(similar(eltype(itr)(), n, size(peek(itr))...)) + Batched(itr, buf) +end + +iteratoreltype(::Type{<:Batched}) = Base.HasEltype() +iteratorsize(::Type{<:Batched}) = Base.SizeUnknown() + +eltype{T,S}(x::Batched{T,S}) = S + +start(::Batched) = () + +next(x::Batched, _) = x.buf, () + +# will be less hacky if https://github.com/JuliaLang/julia/issues/18823 +function done(x::Batched, _) + next = taken!(x.itr, length(x.buf)) + length(next) < length(x.buf) && return true + x.buf[:] = next + return false +end