Added tests for outdims

This commit is contained in:
Kyle Daruwalla 2019-12-05 22:54:25 -06:00
parent 31dda0ce6c
commit 6265b1fa39
4 changed files with 43 additions and 8 deletions

View File

@ -49,7 +49,7 @@ m = Chain(Conv((3, 3), 3 => 16), Conv((3, 3), 16 => 32))
outdims(m, (10, 10)) == (6, 6)
```
"""
outdims(c::Chain, isize) = foldl(, map(l -> (x -> outdims(l, x)), c.layers))
outdims(c::Chain, isize) = foldl(, map(l -> (x -> outdims(l, x)), c.layers))(isize)
# This is a temporary and naive implementation
# it might be replaced in the future for better performance
@ -138,7 +138,7 @@ outdims(m, (5, 2)) == (5,)
outdims(m, (10,)) == (5,)
```
"""
outdims(l::Dense, isize) = (size(l.W)[2],)
outdims(l::Dense, isize) = (size(l.W)[1],)
"""
Diagonal(in::Integer)
@ -234,11 +234,11 @@ end
Calculate the output dimensions given the input dimensions, `isize`.
```julia
m = Maxout(Conv((3, 3), 3 => 16), Conv((3, 3), 16 => 32))
m = Maxout(() -> Conv((3, 3), 3 => 16), 2)
outdims(m, (10, 10)) == (8, 8)
```
"""
outdims(l::Maxout, isize) = outdims(first(l.over))
outdims(l::Maxout, isize) = outdims(first(l.over), isize)
"""
SkipConnection(layers, connection)

View File

@ -1,7 +1,7 @@
using NNlib: conv, ∇conv_data, depthwiseconv
_convoutdims(isize, ksize, ssize, pad) = Int.(floor.((isize .- ksize .+ 2 .* pad) ./ ssize .+ 1))
_convtransoutdims(isize, ksize, ssize, pad) = Int.(ssize .* (isize .- 1) .+ ksize .- 2 .* pad))
_convtransoutdims(isize, ksize, ssize, pad) = Int.(ssize .* (isize .- 1) .+ ksize .- 2 .* pad)
expand(N, i::Tuple) = i
expand(N, i::Integer) = ntuple(_ -> i, N)
@ -238,7 +238,7 @@ end
Calculate the output dimensions given the input dimensions, `isize`.
```julia
m = DepthwiseConv((3, 3), 3 => 16)
m = DepthwiseConv((3, 3), 3 => 6)
outdims(m, (10, 10)) == (8, 8)
```
"""
@ -366,7 +366,7 @@ m = MaxPool((2, 2))
outdims(m, (10, 10)) == (5, 5)
```
"""
outdims(l::MaxPool{N}, isize) where N = _convoutdims(isize, l.weight, l.stride, l.pad[1:N])
outdims(l::MaxPool{N}, isize) where N = _convoutdims(isize, l.k, l.stride, l.pad[1:N])
"""
MeanPool(k)
@ -406,4 +406,4 @@ m = MeanPool((2, 2))
outdims(m, (10, 10)) == (5, 5)
```
"""
outdims(l::MeanPool{N}, isize) where N = _convoutdims(isize, l.weight, l.stride, l.pad[1:N])
outdims(l::MeanPool{N}, isize) where N = _convoutdims(isize, l.k, l.stride, l.pad[1:N])

View File

@ -92,4 +92,19 @@ import Flux: activations
@test size(SkipConnection(Dense(10,10), (a,b) -> cat(a, b, dims = 2))(input)) == (10,4)
end
end
@testset "output dimensions" begin
m = Chain(Conv((3, 3), 3 => 16), Conv((3, 3), 16 => 32))
@test Flux.outdims(m, (10, 10)) == (6, 6)
m = Dense(10, 5)
@test Flux.outdims(m, (5, 2)) == (5,)
@test Flux.outdims(m, (10,)) == (5,)
m = Flux.Diagonal(10)
@test Flux.outdims(m, (10,)) == (10,)
m = Maxout(() -> Conv((3, 3), 3 => 16), 2)
@test Flux.outdims(m, (10, 10)) == (8, 8)
end
end

View File

@ -107,3 +107,23 @@ end
true
end
end
@testset "conv output dimensions" begin
m = Conv((3, 3), 3 => 16)
@test Flux.outdims(m, (10, 10)) == (8, 8)
m = ConvTranspose((3, 3), 3 => 16)
@test Flux.outdims(m, (8, 8)) == (10, 10)
m = DepthwiseConv((3, 3), 3 => 6)
@test Flux.outdims(m, (10, 10)) == (8, 8)
m = CrossCor((3, 3), 3 => 16)
@test Flux.outdims(m, (10, 10)) == (8, 8)
m = MaxPool((2, 2))
@test Flux.outdims(m, (10, 10)) == (5, 5)
m = MeanPool((2, 2))
@test Flux.outdims(m, (10, 10)) == (5, 5)
end