Added tests for outdims
This commit is contained in:
parent
31dda0ce6c
commit
6265b1fa39
|
@ -49,7 +49,7 @@ m = Chain(Conv((3, 3), 3 => 16), Conv((3, 3), 16 => 32))
|
|||
outdims(m, (10, 10)) == (6, 6)
|
||||
```
|
||||
"""
|
||||
outdims(c::Chain, isize) = foldl(∘, map(l -> (x -> outdims(l, x)), c.layers))
|
||||
outdims(c::Chain, isize) = foldl(∘, map(l -> (x -> outdims(l, x)), c.layers))(isize)
|
||||
|
||||
# This is a temporary and naive implementation
|
||||
# it might be replaced in the future for better performance
|
||||
|
@ -138,7 +138,7 @@ outdims(m, (5, 2)) == (5,)
|
|||
outdims(m, (10,)) == (5,)
|
||||
```
|
||||
"""
|
||||
outdims(l::Dense, isize) = (size(l.W)[2],)
|
||||
outdims(l::Dense, isize) = (size(l.W)[1],)
|
||||
|
||||
"""
|
||||
Diagonal(in::Integer)
|
||||
|
@ -234,11 +234,11 @@ end
|
|||
Calculate the output dimensions given the input dimensions, `isize`.
|
||||
|
||||
```julia
|
||||
m = Maxout(Conv((3, 3), 3 => 16), Conv((3, 3), 16 => 32))
|
||||
m = Maxout(() -> Conv((3, 3), 3 => 16), 2)
|
||||
outdims(m, (10, 10)) == (8, 8)
|
||||
```
|
||||
"""
|
||||
outdims(l::Maxout, isize) = outdims(first(l.over))
|
||||
outdims(l::Maxout, isize) = outdims(first(l.over), isize)
|
||||
|
||||
"""
|
||||
SkipConnection(layers, connection)
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
using NNlib: conv, ∇conv_data, depthwiseconv
|
||||
|
||||
_convoutdims(isize, ksize, ssize, pad) = Int.(floor.((isize .- ksize .+ 2 .* pad) ./ ssize .+ 1))
|
||||
_convtransoutdims(isize, ksize, ssize, pad) = Int.(ssize .* (isize .- 1) .+ ksize .- 2 .* pad))
|
||||
_convtransoutdims(isize, ksize, ssize, pad) = Int.(ssize .* (isize .- 1) .+ ksize .- 2 .* pad)
|
||||
|
||||
expand(N, i::Tuple) = i
|
||||
expand(N, i::Integer) = ntuple(_ -> i, N)
|
||||
|
@ -238,7 +238,7 @@ end
|
|||
Calculate the output dimensions given the input dimensions, `isize`.
|
||||
|
||||
```julia
|
||||
m = DepthwiseConv((3, 3), 3 => 16)
|
||||
m = DepthwiseConv((3, 3), 3 => 6)
|
||||
outdims(m, (10, 10)) == (8, 8)
|
||||
```
|
||||
"""
|
||||
|
@ -366,7 +366,7 @@ m = MaxPool((2, 2))
|
|||
outdims(m, (10, 10)) == (5, 5)
|
||||
```
|
||||
"""
|
||||
outdims(l::MaxPool{N}, isize) where N = _convoutdims(isize, l.weight, l.stride, l.pad[1:N])
|
||||
outdims(l::MaxPool{N}, isize) where N = _convoutdims(isize, l.k, l.stride, l.pad[1:N])
|
||||
|
||||
"""
|
||||
MeanPool(k)
|
||||
|
@ -406,4 +406,4 @@ m = MeanPool((2, 2))
|
|||
outdims(m, (10, 10)) == (5, 5)
|
||||
```
|
||||
"""
|
||||
outdims(l::MeanPool{N}, isize) where N = _convoutdims(isize, l.weight, l.stride, l.pad[1:N])
|
||||
outdims(l::MeanPool{N}, isize) where N = _convoutdims(isize, l.k, l.stride, l.pad[1:N])
|
|
@ -92,4 +92,19 @@ import Flux: activations
|
|||
@test size(SkipConnection(Dense(10,10), (a,b) -> cat(a, b, dims = 2))(input)) == (10,4)
|
||||
end
|
||||
end
|
||||
|
||||
@testset "output dimensions" begin
|
||||
m = Chain(Conv((3, 3), 3 => 16), Conv((3, 3), 16 => 32))
|
||||
@test Flux.outdims(m, (10, 10)) == (6, 6)
|
||||
|
||||
m = Dense(10, 5)
|
||||
@test Flux.outdims(m, (5, 2)) == (5,)
|
||||
@test Flux.outdims(m, (10,)) == (5,)
|
||||
|
||||
m = Flux.Diagonal(10)
|
||||
@test Flux.outdims(m, (10,)) == (10,)
|
||||
|
||||
m = Maxout(() -> Conv((3, 3), 3 => 16), 2)
|
||||
@test Flux.outdims(m, (10, 10)) == (8, 8)
|
||||
end
|
||||
end
|
||||
|
|
|
@ -107,3 +107,23 @@ end
|
|||
true
|
||||
end
|
||||
end
|
||||
|
||||
@testset "conv output dimensions" begin
|
||||
m = Conv((3, 3), 3 => 16)
|
||||
@test Flux.outdims(m, (10, 10)) == (8, 8)
|
||||
|
||||
m = ConvTranspose((3, 3), 3 => 16)
|
||||
@test Flux.outdims(m, (8, 8)) == (10, 10)
|
||||
|
||||
m = DepthwiseConv((3, 3), 3 => 6)
|
||||
@test Flux.outdims(m, (10, 10)) == (8, 8)
|
||||
|
||||
m = CrossCor((3, 3), 3 => 16)
|
||||
@test Flux.outdims(m, (10, 10)) == (8, 8)
|
||||
|
||||
m = MaxPool((2, 2))
|
||||
@test Flux.outdims(m, (10, 10)) == (5, 5)
|
||||
|
||||
m = MeanPool((2, 2))
|
||||
@test Flux.outdims(m, (10, 10)) == (5, 5)
|
||||
end
|
Loading…
Reference in New Issue