diff --git a/test/layers/conv.jl b/test/layers/conv.jl index 69958908..2b9b04e2 100644 --- a/test/layers/conv.jl +++ b/test/layers/conv.jl @@ -61,3 +61,29 @@ end x_hat = ConvTranspose((3, 3), 1 => 1)(y) @test size(x_hat) == size(x) end + +@testset "Conv with non quadratic window #700" begin + data = zeros(Float32, 7,7,1,1) + data[4,4,1,1] = 1 + + l = Conv((3,3), 1=>1) + expected = zeros(eltype(l.weight),5,5,1,1) + expected[2:end-1,2:end-1,1,1] = l.weight + @test expected == l(data) + + l = Conv((3,1), 1=>1) + expected = zeros(eltype(l.weight),5,7,1,1) + expected[2:end-1,4,1,1] = l.weight + @test expected == l(data) + + l = Conv((1,3), 1=>1) + expected = zeros(eltype(l.weight),7,5,1,1) + expected[4,2:end-1,1,1] = l.weight + @test expected == l(data) + + @test begin + # we test that the next expression does not throw + randn(Float32, 10,10,1,1) |> Conv((6,1), 1=>1, Flux.σ) + true + end +end