diff --git a/src/utils.jl b/src/utils.jl index 9bad3760..1bdfe456 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -10,8 +10,8 @@ zeros(dims...) = Base.zeros(Float32, dims...) unsqueeze(xs, dim) = reshape(xs, (size(xs)[1:dim-1]..., 1, size(xs)[dim:end]...)) -stack(xs, dim) = cat(dim, unsqueeze.(xs, dim)...) -unstack(xs, dim) = [slicedim(xs, dim, i) for i = 1:size(xs, dim)] +stack(xs, dim) = cat(unsqueeze.(xs, dim)..., dims=dim) +unstack(xs, dim) = [copy(selectdim(xs, dim, i)) for i in 1:size(xs, dim)] """ chunk(xs, n)