diff --git a/test/utils.jl b/test/utils.jl index 2b4692a9..7bcf72c3 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -87,6 +87,12 @@ end @test size.(params(m)) == [(5, 10), (5, 5), (5,), (5,)] end +@testset "Basic Stacking" begin + x = randn(3,3) + stacked = stack([x, x], 2) + @test size(stacked) == (3,2,3) +end + @testset "Precision" begin m = Chain(Dense(10, 5, relu), Dense(5, 2)) x = rand(10)