From 9b897fc60139a56825804068c7b73745371069c6 Mon Sep 17 00:00:00 2001 From: kolia Date: Thu, 20 Dec 2018 10:03:21 -0500 Subject: [PATCH 1/2] Tiny bugfix: `stack` was still calling julia 0.6 `cat` Also added tiny test for good measure. --- src/utils.jl | 2 +- test/utils.jl | 8 +++++++- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/src/utils.jl b/src/utils.jl index 9bad3760..2e3d5c4b 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -10,7 +10,7 @@ zeros(dims...) = Base.zeros(Float32, dims...) unsqueeze(xs, dim) = reshape(xs, (size(xs)[1:dim-1]..., 1, size(xs)[dim:end]...)) -stack(xs, dim) = cat(dim, unsqueeze.(xs, dim)...) +stack(xs, dim) = cat(unsqueeze.(xs, dim)...; dims=dim) unstack(xs, dim) = [slicedim(xs, dim, i) for i = 1:size(xs, dim)] """ diff --git a/test/utils.jl b/test/utils.jl index af0d50fe..d07f88c9 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -1,5 +1,5 @@ using Flux -using Flux: throttle, jacobian, glorot_uniform, glorot_normal +using Flux: throttle, jacobian, glorot_uniform, glorot_normal, stack using StatsBase: std using Random using Test @@ -86,3 +86,9 @@ end m = RNN(10, 5) @test size.(params(m)) == [(5, 10), (5, 5), (5,), (5,)] end + +@testset "Basic" begin + x = randn(3,3) + stacked = stack([x, x], 2) + @test size(stacked) == (3,2,3) +end From eb9da4084ffc55858398dd8b0bb1adaefb261a38 Mon Sep 17 00:00:00 2001 From: Dhairya Gandhi Date: Fri, 15 Feb 2019 20:33:21 +0530 Subject: [PATCH 2/2] remove spurious line change --- test/utils.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/utils.jl b/test/utils.jl index 3887fb3d..7bcf72c3 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -110,4 +110,4 @@ end @test unstack(stacked_array, 2) == unstacked_array @test stack(unstacked_array, 2) == stacked_array @test stack(unstack(stacked_array, 1), 1) == stacked_array -end \ No newline at end of file +end