diff --git a/src/utils.jl b/src/utils.jl index 1be5ded5..f822c111 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -9,6 +9,21 @@ unsqueeze(xs, dim) = reshape(xs, (size(xs)[1:dim-1]..., 1, size(xs)[dim:end]...) stack(xs, dim) = cat(dim, unsqueeze.(xs, dim)...) unstack(xs, dim) = [slicedim(xs, dim, i) for i = 1:size(xs, dim)] +""" + chunk(xs, n) + +Split `xs` into `n` parts. + +```julia +julia> chunk(1:10, 3) +3-element Array{Array{Int64,1},1}: + [1, 2, 3, 4] + [5, 6, 7, 8] + [9, 10] +``` +""" +chunk(xs, n) = collect(Iterators.partition(xs, ceil(Int, length(xs)/n))) + batchindex(xs, i) = (reverse(Base.tail(reverse(indices(xs))))..., i) """