Added tests for outdims

This commit is contained in:
Kyle Daruwalla 2019-12-05 22:54:25 -06:00
parent 31dda0ce6c
commit 6265b1fa39
4 changed files with 43 additions and 8 deletions

View File

@ -49,7 +49,7 @@ m = Chain(Conv((3, 3), 3 => 16), Conv((3, 3), 16 => 32))
outdims(m, (10, 10)) == (6, 6) outdims(m, (10, 10)) == (6, 6)
``` ```
""" """
outdims(c::Chain, isize) = foldl(, map(l -> (x -> outdims(l, x)), c.layers)) outdims(c::Chain, isize) = foldl(, map(l -> (x -> outdims(l, x)), c.layers))(isize)
# This is a temporary and naive implementation # This is a temporary and naive implementation
# it might be replaced in the future for better performance # it might be replaced in the future for better performance
@ -138,7 +138,7 @@ outdims(m, (5, 2)) == (5,)
outdims(m, (10,)) == (5,) outdims(m, (10,)) == (5,)
``` ```
""" """
outdims(l::Dense, isize) = (size(l.W)[2],) outdims(l::Dense, isize) = (size(l.W)[1],)
""" """
Diagonal(in::Integer) Diagonal(in::Integer)
@ -234,11 +234,11 @@ end
Calculate the output dimensions given the input dimensions, `isize`. Calculate the output dimensions given the input dimensions, `isize`.
```julia ```julia
m = Maxout(Conv((3, 3), 3 => 16), Conv((3, 3), 16 => 32)) m = Maxout(() -> Conv((3, 3), 3 => 16), 2)
outdims(m, (10, 10)) == (8, 8) outdims(m, (10, 10)) == (8, 8)
``` ```
""" """
outdims(l::Maxout, isize) = outdims(first(l.over)) outdims(l::Maxout, isize) = outdims(first(l.over), isize)
""" """
SkipConnection(layers, connection) SkipConnection(layers, connection)

View File

@ -1,7 +1,7 @@
using NNlib: conv, ∇conv_data, depthwiseconv using NNlib: conv, ∇conv_data, depthwiseconv
_convoutdims(isize, ksize, ssize, pad) = Int.(floor.((isize .- ksize .+ 2 .* pad) ./ ssize .+ 1)) _convoutdims(isize, ksize, ssize, pad) = Int.(floor.((isize .- ksize .+ 2 .* pad) ./ ssize .+ 1))
_convtransoutdims(isize, ksize, ssize, pad) = Int.(ssize .* (isize .- 1) .+ ksize .- 2 .* pad)) _convtransoutdims(isize, ksize, ssize, pad) = Int.(ssize .* (isize .- 1) .+ ksize .- 2 .* pad)
expand(N, i::Tuple) = i expand(N, i::Tuple) = i
expand(N, i::Integer) = ntuple(_ -> i, N) expand(N, i::Integer) = ntuple(_ -> i, N)
@ -238,7 +238,7 @@ end
Calculate the output dimensions given the input dimensions, `isize`. Calculate the output dimensions given the input dimensions, `isize`.
```julia ```julia
m = DepthwiseConv((3, 3), 3 => 16) m = DepthwiseConv((3, 3), 3 => 6)
outdims(m, (10, 10)) == (8, 8) outdims(m, (10, 10)) == (8, 8)
``` ```
""" """
@ -366,7 +366,7 @@ m = MaxPool((2, 2))
outdims(m, (10, 10)) == (5, 5) outdims(m, (10, 10)) == (5, 5)
``` ```
""" """
outdims(l::MaxPool{N}, isize) where N = _convoutdims(isize, l.weight, l.stride, l.pad[1:N]) outdims(l::MaxPool{N}, isize) where N = _convoutdims(isize, l.k, l.stride, l.pad[1:N])
""" """
MeanPool(k) MeanPool(k)
@ -406,4 +406,4 @@ m = MeanPool((2, 2))
outdims(m, (10, 10)) == (5, 5) outdims(m, (10, 10)) == (5, 5)
``` ```
""" """
outdims(l::MeanPool{N}, isize) where N = _convoutdims(isize, l.weight, l.stride, l.pad[1:N]) outdims(l::MeanPool{N}, isize) where N = _convoutdims(isize, l.k, l.stride, l.pad[1:N])

View File

@ -92,4 +92,19 @@ import Flux: activations
@test size(SkipConnection(Dense(10,10), (a,b) -> cat(a, b, dims = 2))(input)) == (10,4) @test size(SkipConnection(Dense(10,10), (a,b) -> cat(a, b, dims = 2))(input)) == (10,4)
end end
end end
@testset "output dimensions" begin
m = Chain(Conv((3, 3), 3 => 16), Conv((3, 3), 16 => 32))
@test Flux.outdims(m, (10, 10)) == (6, 6)
m = Dense(10, 5)
@test Flux.outdims(m, (5, 2)) == (5,)
@test Flux.outdims(m, (10,)) == (5,)
m = Flux.Diagonal(10)
@test Flux.outdims(m, (10,)) == (10,)
m = Maxout(() -> Conv((3, 3), 3 => 16), 2)
@test Flux.outdims(m, (10, 10)) == (8, 8)
end
end end

View File

@ -107,3 +107,23 @@ end
true true
end end
end end
@testset "conv output dimensions" begin
m = Conv((3, 3), 3 => 16)
@test Flux.outdims(m, (10, 10)) == (8, 8)
m = ConvTranspose((3, 3), 3 => 16)
@test Flux.outdims(m, (8, 8)) == (10, 10)
m = DepthwiseConv((3, 3), 3 => 6)
@test Flux.outdims(m, (10, 10)) == (8, 8)
m = CrossCor((3, 3), 3 => 16)
@test Flux.outdims(m, (10, 10)) == (8, 8)
m = MaxPool((2, 2))
@test Flux.outdims(m, (10, 10)) == (5, 5)
m = MeanPool((2, 2))
@test Flux.outdims(m, (10, 10)) == (5, 5)
end