diff --git a/src/dims/batching.jl b/src/dims/batching.jl index d1a99e4d..d204bd18 100644 --- a/src/dims/batching.jl +++ b/src/dims/batching.jl @@ -32,7 +32,7 @@ convertel(T::Type, xs::Batch) = # TODO: remove this in favour of full batching semantics mapt(f, x) = f(x) -mapt(f, xs::Tuple) = map(f, xs) +mapt(f, xs::Tuple) = map(x -> mapt(f, x), xs) batchone(x) = Batch((x,)) batchone(x::Batch) = x