diff --git a/src/dims/batching.jl b/src/dims/batching.jl index a3ae6ed8..df7f81f6 100644 --- a/src/dims/batching.jl +++ b/src/dims/batching.jl @@ -27,3 +27,6 @@ end convertel(T::Type, xs::Batch) = eltype(eltype(xs)) isa T ? xs : Batch(map(x->convertel(T, x), xs)) + +batchone(x) = Batch((x,)) +batchone(x::Batch) = x