diff --git a/src/dims/utils.jl b/src/dims/utils.jl index 680c85a1..7cc00476 100644 --- a/src/dims/utils.jl +++ b/src/dims/utils.jl @@ -1,4 +1,5 @@ -unsqueeze(xs, dim) = reshape(xs, (size(xs)[1:dim-1]..., 1, size(xs)[dim:end]...)) +unsqueeze(xs, dim = 1) = reshape(xs, (size(xs)[1:dim-1]..., 1, size(xs)[dim:end]...)) +Base.squeeze(xs) = squeeze(xs, 1) stack(xs, dim = 1) = cat(dim, unsqueeze.(xs, dim)...) unstack(xs, dim = 1) = [slicedim(xs, dim, i) for i = 1:size(xs, dim)]