commit
cb4d8cf9a6
@ -19,6 +19,7 @@ include("model.jl")
|
||||
include("dims/catmat.jl")
|
||||
include("dims/batching.jl")
|
||||
include("dims/seq.jl")
|
||||
include("dims/iter.jl")
|
||||
|
||||
include("compiler/code.jl")
|
||||
include("compiler/loops.jl")
|
||||
|
@ -15,7 +15,13 @@ size(b::CatMat) = (size(b.data, 1),)
|
||||
|
||||
getindex(b::CatMat, i)::eltype(b) = slicedim(b.data, 1, i)
|
||||
|
||||
setindex!(b::CatMat, v, i) = b.data[i, :] = v
|
||||
setindex!(b::CatMat, v, i::Integer) = b.data[i, :] = v
|
||||
|
||||
function setindex!(b::CatMat, xs, ::Colon)
|
||||
for (i, x) in enumerate(xs)
|
||||
b[i] = x
|
||||
end
|
||||
end
|
||||
|
||||
allequal(xs) = all(x -> x == first(xs), xs)
|
||||
|
||||
|
72
src/dims/iter.jl
Normal file
72
src/dims/iter.jl
Normal file
@ -0,0 +1,72 @@
|
||||
export Batched
|
||||
|
||||
import Base: start, next, done, iteratorsize, iteratoreltype, eltype, length
|
||||
|
||||
# Stateful iteration
|
||||
|
||||
mutable struct StatefulIter{I,S,T}
|
||||
iter::I
|
||||
state::S
|
||||
next::Nullable{T}
|
||||
end
|
||||
|
||||
function StatefulIter(itr)
|
||||
state = start(itr)
|
||||
val, state = done(itr, state) ? (Nullable(), state) : next(itr, state)
|
||||
return StatefulIter(itr, state, convert(Nullable, val))
|
||||
end
|
||||
|
||||
peek(s::StatefulIter) = get(s.next)
|
||||
|
||||
function Base.take!(s::StatefulIter)
|
||||
x = peek(s)
|
||||
if !done(s.iter, s.state)
|
||||
s.next, s.state = next(s.iter, s.state)
|
||||
else
|
||||
s.next = Nullable()
|
||||
end
|
||||
return x
|
||||
end
|
||||
|
||||
Base.isempty(s::StatefulIter) = isnull(s.next)
|
||||
Base.eltype(s::StatefulIter) = eltype(s.next)
|
||||
|
||||
function taken!(s::StatefulIter, n::Integer)
|
||||
xs = eltype(s)[]
|
||||
for _ = 1:n
|
||||
isempty(s) && break
|
||||
push!(xs, take!(s))
|
||||
end
|
||||
return xs
|
||||
end
|
||||
|
||||
# Batched
|
||||
|
||||
struct Batched{I<:StatefulIter,S}
|
||||
itr::I
|
||||
buf::S
|
||||
end
|
||||
|
||||
function Batched(itr, n::Integer)
|
||||
n >= 1 || throw(ArgumentError("batch size must be >= 1"))
|
||||
itr = StatefulIter(itr)
|
||||
buf = rebatch(similar(eltype(itr)(), n, size(peek(itr))...))
|
||||
Batched(itr, buf)
|
||||
end
|
||||
|
||||
iteratoreltype(::Type{<:Batched}) = Base.HasEltype()
|
||||
iteratorsize(::Type{<:Batched}) = Base.SizeUnknown()
|
||||
|
||||
eltype{T,S}(x::Batched{T,S}) = S
|
||||
|
||||
start(::Batched) = ()
|
||||
|
||||
next(x::Batched, _) = x.buf, ()
|
||||
|
||||
# will be less hacky if https://github.com/JuliaLang/julia/issues/18823
|
||||
function done(x::Batched, _)
|
||||
next = taken!(x.itr, length(x.buf))
|
||||
length(next) < length(x.buf) && return true
|
||||
x.buf[:] = next
|
||||
return false
|
||||
end
|
Loading…
Reference in New Issue
Block a user