diff --git a/test/utils.jl b/test/utils.jl index c60645c6..2b4692a9 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -1,5 +1,5 @@ using Flux -using Flux: throttle, jacobian, glorot_uniform, glorot_normal +using Flux: throttle, jacobian, glorot_uniform, glorot_normal, stack, unstack using StatsBase: std using Random using Test @@ -97,3 +97,11 @@ end @test eltype(f32(f64(m))[1].W.data) == Float32 @test Tracker.isleaf(f32(f64(m))[1].W) end + +@testset "Stacking" begin + stacked_array=[ 8 9 3 5; 9 6 6 9; 9 1 7 2; 7 4 10 6 ] + unstacked_array=[[8, 9, 9, 7], [9, 6, 1, 4], [3, 6, 7, 10], [5, 9, 2, 6]] + @test unstack(stacked_array, 2) == unstacked_array + @test stack(unstacked_array, 2) == stacked_array + @test stack(unstack(stacked_array, 1), 1) == stacked_array +end