From 6265b1fa39c5d7d289ccd5a00c94ae9f448377fc Mon Sep 17 00:00:00 2001 From: Kyle Daruwalla Date: Thu, 5 Dec 2019 22:54:25 -0600 Subject: [PATCH] Added tests for outdims --- src/layers/basic.jl | 8 ++++---- src/layers/conv.jl | 8 ++++---- test/layers/basic.jl | 15 +++++++++++++++ test/layers/conv.jl | 20 ++++++++++++++++++++ 4 files changed, 43 insertions(+), 8 deletions(-) diff --git a/src/layers/basic.jl b/src/layers/basic.jl index 8794b58c..b62d8bb9 100644 --- a/src/layers/basic.jl +++ b/src/layers/basic.jl @@ -49,7 +49,7 @@ m = Chain(Conv((3, 3), 3 => 16), Conv((3, 3), 16 => 32)) outdims(m, (10, 10)) == (6, 6) ``` """ -outdims(c::Chain, isize) = foldl(∘, map(l -> (x -> outdims(l, x)), c.layers)) +outdims(c::Chain, isize) = foldl(∘, map(l -> (x -> outdims(l, x)), c.layers))(isize) # This is a temporary and naive implementation # it might be replaced in the future for better performance @@ -138,7 +138,7 @@ outdims(m, (5, 2)) == (5,) outdims(m, (10,)) == (5,) ``` """ -outdims(l::Dense, isize) = (size(l.W)[2],) +outdims(l::Dense, isize) = (size(l.W)[1],) """ Diagonal(in::Integer) @@ -234,11 +234,11 @@ end Calculate the output dimensions given the input dimensions, `isize`. ```julia -m = Maxout(Conv((3, 3), 3 => 16), Conv((3, 3), 16 => 32)) +m = Maxout(() -> Conv((3, 3), 3 => 16), 2) outdims(m, (10, 10)) == (8, 8) ``` """ -outdims(l::Maxout, isize) = outdims(first(l.over)) +outdims(l::Maxout, isize) = outdims(first(l.over), isize) """ SkipConnection(layers, connection) diff --git a/src/layers/conv.jl b/src/layers/conv.jl index 2e3e87d7..6ce9bcbf 100644 --- a/src/layers/conv.jl +++ b/src/layers/conv.jl @@ -1,7 +1,7 @@ using NNlib: conv, ∇conv_data, depthwiseconv _convoutdims(isize, ksize, ssize, pad) = Int.(floor.((isize .- ksize .+ 2 .* pad) ./ ssize .+ 1)) -_convtransoutdims(isize, ksize, ssize, pad) = Int.(ssize .* (isize .- 1) .+ ksize .- 2 .* pad)) +_convtransoutdims(isize, ksize, ssize, pad) = Int.(ssize .* (isize .- 1) .+ ksize .- 2 .* pad) expand(N, i::Tuple) = i expand(N, i::Integer) = ntuple(_ -> i, N) @@ -238,7 +238,7 @@ end Calculate the output dimensions given the input dimensions, `isize`. ```julia -m = DepthwiseConv((3, 3), 3 => 16) +m = DepthwiseConv((3, 3), 3 => 6) outdims(m, (10, 10)) == (8, 8) ``` """ @@ -366,7 +366,7 @@ m = MaxPool((2, 2)) outdims(m, (10, 10)) == (5, 5) ``` """ -outdims(l::MaxPool{N}, isize) where N = _convoutdims(isize, l.weight, l.stride, l.pad[1:N]) +outdims(l::MaxPool{N}, isize) where N = _convoutdims(isize, l.k, l.stride, l.pad[1:N]) """ MeanPool(k) @@ -406,4 +406,4 @@ m = MeanPool((2, 2)) outdims(m, (10, 10)) == (5, 5) ``` """ -outdims(l::MeanPool{N}, isize) where N = _convoutdims(isize, l.weight, l.stride, l.pad[1:N]) \ No newline at end of file +outdims(l::MeanPool{N}, isize) where N = _convoutdims(isize, l.k, l.stride, l.pad[1:N]) \ No newline at end of file diff --git a/test/layers/basic.jl b/test/layers/basic.jl index 0ff1776d..421c7721 100644 --- a/test/layers/basic.jl +++ b/test/layers/basic.jl @@ -92,4 +92,19 @@ import Flux: activations @test size(SkipConnection(Dense(10,10), (a,b) -> cat(a, b, dims = 2))(input)) == (10,4) end end + + @testset "output dimensions" begin + m = Chain(Conv((3, 3), 3 => 16), Conv((3, 3), 16 => 32)) + @test Flux.outdims(m, (10, 10)) == (6, 6) + + m = Dense(10, 5) + @test Flux.outdims(m, (5, 2)) == (5,) + @test Flux.outdims(m, (10,)) == (5,) + + m = Flux.Diagonal(10) + @test Flux.outdims(m, (10,)) == (10,) + + m = Maxout(() -> Conv((3, 3), 3 => 16), 2) + @test Flux.outdims(m, (10, 10)) == (8, 8) + end end diff --git a/test/layers/conv.jl b/test/layers/conv.jl index b4136062..5701df80 100644 --- a/test/layers/conv.jl +++ b/test/layers/conv.jl @@ -107,3 +107,23 @@ end true end end + +@testset "conv output dimensions" begin + m = Conv((3, 3), 3 => 16) + @test Flux.outdims(m, (10, 10)) == (8, 8) + + m = ConvTranspose((3, 3), 3 => 16) + @test Flux.outdims(m, (8, 8)) == (10, 10) + + m = DepthwiseConv((3, 3), 3 => 6) + @test Flux.outdims(m, (10, 10)) == (8, 8) + + m = CrossCor((3, 3), 3 => 16) + @test Flux.outdims(m, (10, 10)) == (8, 8) + + m = MaxPool((2, 2)) + @test Flux.outdims(m, (10, 10)) == (5, 5) + + m = MeanPool((2, 2)) + @test Flux.outdims(m, (10, 10)) == (5, 5) +end \ No newline at end of file