nested batch tweaks

This commit is contained in:
Mike J Innes 2017-06-09 18:55:21 +01:00
parent 9c8dbb6b4b
commit 65400f20ab
2 changed files with 9 additions and 3 deletions

View File

@ -59,6 +59,7 @@ end
rawbatch(xs::Batchable) = rawbatch(storage(xs))
size(xs::Batchable) = size(storage(xs))
getindex(xs::Batchable, i) = getindex(storage(xs), i)
setindex!(xs::Batchable, v, i...) = setindex!(storage(xs), v, i...)
Base.vcat{T<:Batchable}(xs::T, ys::T)::T = vcat(rawbatch(xs), rawbatch(ys))
@ -84,7 +85,8 @@ dimdec(T::Type{<:AbstractArray}) = deparam(T){eltype(T),ndims(T)-1}
btype(B::Type, S::Type{<:AbstractArray}) = B
btype(B::Type{<:Batchable}, S::Type{<:AbstractArray}) = B{dimdec(S),S}
btype(B::Type{<:Batchable{T}} where T, S::Type{<:AbstractArray}) = B{S}
btype{T}(B::Type{<:Batchable{T}}, S::Type{<:AbstractArray}) = B{S}
btype{T,S<:AbstractArray}(B::Type{<:Batchable{T,S}}, ::Type{S}) = B
btype(B::Type{<:Batchable{<:Batchable}}, S::Type{<:AbstractArray}) =
deparam(B){btype(eltype(B), dimdec(S)),S}

View File

@ -48,7 +48,9 @@ end
function Batched(itr, n::Integer)
n >= 1 || throw(ArgumentError("batch size must be >= 1"))
itr = StatefulIter(itr)
buf = convert(Batch, similar(eltype(itr)(), n, size(peek(itr))...))
x = peek(itr)
buf = convert(Batch{typeof(peek(itr))},
similar(rawbatch(x), n, size(rawbatch(x))...))
Batched(itr, buf)
end
@ -65,6 +67,8 @@ next(x::Batched, _) = x.buf, ()
function done(x::Batched, _)
next = taken!(x.itr, length(x.buf))
length(next) < length(x.buf) && return true
x.buf[:] = next
for (i, n) in enumerate(next)
x.buf[i] = rawbatch(n)
end
return false
end