diff --git a/src/dims/batching.jl b/src/dims/batching.jl index 24dfc85d..fa8a4ce2 100644 --- a/src/dims/batching.jl +++ b/src/dims/batching.jl @@ -36,3 +36,7 @@ function rebatch(xs) B = Array{eltype(xs),dims+1} Batch{T,B}(xs) end + +convertel(T::Type, xs::Batch) = + isa(eltype(eltype(xs)), T) ? xs : + Batch(map(x->convertel(T, x), xs))