Merge #960
960: Added utility function outdims to compute output dimensions of a layer r=dhairyagandhi96 a=darsnack Based on Slack chatter, I added a utility function, `outdims`, that computes the output dimensions for given input dimensions. Example ```julia layer = Conv((3, 3), 3 => 16) outdims(layer, (10, 10)) # returns (8, 8) ``` Co-authored-by: Kyle Daruwalla <daruwalla@wisc.edu>
This commit is contained in:
commit
55616afc11
|
@ -219,3 +219,24 @@ Flux.@functor Affine
|
|||
```
|
||||
|
||||
This enables a useful extra set of functionality for our `Affine` layer, such as [collecting its parameters](../training/optimisers.md) or [moving it to the GPU](../gpu.md).
|
||||
|
||||
## Utility functions
|
||||
|
||||
Flux provides some utility functions to help you generate models in an automated fashion.
|
||||
|
||||
`outdims` enables you to calculate the spatial output dimensions of layers like `Conv` when applied to input images of a given size.
|
||||
Currently limited to the following layers:
|
||||
- `Chain`
|
||||
- `Dense`
|
||||
- `Conv`
|
||||
- `Diagonal`
|
||||
- `Maxout`
|
||||
- `ConvTranspose`
|
||||
- `DepthwiseConv`
|
||||
- `CrossCor`
|
||||
- `MaxPool`
|
||||
- `MeanPool`
|
||||
|
||||
```@docs
|
||||
outdims
|
||||
```
|
||||
|
|
|
@ -39,6 +39,17 @@ function Base.show(io::IO, c::Chain)
|
|||
print(io, ")")
|
||||
end
|
||||
|
||||
"""
|
||||
outdims(c::Chain, isize)
|
||||
|
||||
Calculate the output dimensions given the input dimensions, `isize`.
|
||||
|
||||
```julia
|
||||
m = Chain(Conv((3, 3), 3 => 16), Conv((3, 3), 16 => 32))
|
||||
outdims(m, (10, 10)) == (6, 6)
|
||||
```
|
||||
"""
|
||||
outdims(c::Chain, isize) = foldl(∘, map(l -> (x -> outdims(l, x)), c.layers))(isize)
|
||||
|
||||
# This is a temporary and naive implementation
|
||||
# it might be replaced in the future for better performance
|
||||
|
@ -116,6 +127,19 @@ end
|
|||
(a::Dense{<:Any,W})(x::AbstractArray{<:AbstractFloat}) where {T <: Union{Float32,Float64}, W <: AbstractArray{T}} =
|
||||
a(T.(x))
|
||||
|
||||
"""
|
||||
outdims(l::Dense, isize)
|
||||
|
||||
Calculate the output dimensions given the input dimensions, `isize`.
|
||||
|
||||
```julia
|
||||
m = Dense(10, 5)
|
||||
outdims(m, (5, 2)) == (5,)
|
||||
outdims(m, (10,)) == (5,)
|
||||
```
|
||||
"""
|
||||
outdims(l::Dense, isize) = (size(l.W)[1],)
|
||||
|
||||
"""
|
||||
Diagonal(in::Integer)
|
||||
|
||||
|
@ -145,6 +169,7 @@ function Base.show(io::IO, l::Diagonal)
|
|||
print(io, "Diagonal(", length(l.α), ")")
|
||||
end
|
||||
|
||||
outdims(l::Diagonal, isize) = (length(l.α),)
|
||||
|
||||
"""
|
||||
Maxout(over)
|
||||
|
@ -193,6 +218,8 @@ function (mo::Maxout)(input::AbstractArray)
|
|||
mapreduce(f -> f(input), (acc, out) -> max.(acc, out), mo.over)
|
||||
end
|
||||
|
||||
outdims(l::Maxout, isize) = outdims(first(l.over), isize)
|
||||
|
||||
"""
|
||||
SkipConnection(layers, connection)
|
||||
|
||||
|
|
|
@ -1,4 +1,9 @@
|
|||
using NNlib: conv, ∇conv_data, depthwiseconv
|
||||
using NNlib: conv, ∇conv_data, depthwiseconv, output_size
|
||||
|
||||
# pad dims of x with dims of y until ndims(x) == ndims(y)
|
||||
_paddims(x::Tuple, y::Tuple) = (x..., y[(end - (length(y) - length(x) - 1)):end]...)
|
||||
|
||||
_convtransoutdims(isize, ksize, ssize, dsize, pad) = (isize .- 1).*ssize .+ 1 .+ (ksize .- 1).*dsize .- (pad[1:2:end] .+ pad[2:2:end])
|
||||
|
||||
expand(N, i::Tuple) = i
|
||||
expand(N, i::Integer) = ntuple(_ -> i, N)
|
||||
|
@ -68,6 +73,21 @@ end
|
|||
(a::Conv{<:Any,<:Any,W})(x::AbstractArray{<:Real}) where {T <: Union{Float32,Float64}, W <: AbstractArray{T}} =
|
||||
a(T.(x))
|
||||
|
||||
"""
|
||||
outdims(l::Conv, isize::Tuple)
|
||||
|
||||
Calculate the output dimensions given the input dimensions, `isize`.
|
||||
Batch size and channel size are ignored as per `NNlib.jl`.
|
||||
|
||||
```julia
|
||||
m = Conv((3, 3), 3 => 16)
|
||||
outdims(m, (10, 10)) == (8, 8)
|
||||
outdims(m, (10, 10, 1, 3)) == (8, 8)
|
||||
```
|
||||
"""
|
||||
outdims(l::Conv, isize) =
|
||||
output_size(DenseConvDims(_paddims(isize, size(l.weight)), size(l.weight); stride = l.stride, padding = l.pad, dilation = l.dilation))
|
||||
|
||||
"""
|
||||
ConvTranspose(size, in=>out)
|
||||
ConvTranspose(size, in=>out, relu)
|
||||
|
@ -140,6 +160,9 @@ end
|
|||
|
||||
(a::ConvTranspose{<:Any,<:Any,W})(x::AbstractArray{<:Real}) where {T <: Union{Float32,Float64}, W <: AbstractArray{T}} =
|
||||
a(T.(x))
|
||||
|
||||
outdims(l::ConvTranspose{N}, isize) where N = _convtransoutdims(isize[1:2], size(l.weight)[1:N], l.stride, l.dilation, l.pad)
|
||||
|
||||
"""
|
||||
DepthwiseConv(size, in=>out)
|
||||
DepthwiseConv(size, in=>out, relu)
|
||||
|
@ -204,6 +227,9 @@ end
|
|||
(a::DepthwiseConv{<:Any,<:Any,W})(x::AbstractArray{<:Real}) where {T <: Union{Float32,Float64}, W <: AbstractArray{T}} =
|
||||
a(T.(x))
|
||||
|
||||
outdims(l::DepthwiseConv, isize) =
|
||||
output_size(DepthwiseConvDims(_paddims(isize, (1, 1, size(l.weight)[end], 1)), size(l.weight); stride = l.stride, padding = l.pad, dilation = l.dilation))
|
||||
|
||||
"""
|
||||
CrossCor(size, in=>out)
|
||||
CrossCor(size, in=>out, relu)
|
||||
|
@ -275,6 +301,9 @@ end
|
|||
(a::CrossCor{<:Any,<:Any,W})(x::AbstractArray{<:Real}) where {T <: Union{Float32,Float64}, W <: AbstractArray{T}} =
|
||||
a(T.(x))
|
||||
|
||||
outdims(l::CrossCor, isize) =
|
||||
output_size(DenseConvDims(_paddims(isize, size(l.weight)), size(l.weight); stride = l.stride, padding = l.pad, dilation = l.dilation))
|
||||
|
||||
"""
|
||||
MaxPool(k)
|
||||
|
||||
|
@ -304,6 +333,8 @@ function Base.show(io::IO, m::MaxPool)
|
|||
print(io, "MaxPool(", m.k, ", pad = ", m.pad, ", stride = ", m.stride, ")")
|
||||
end
|
||||
|
||||
outdims(l::MaxPool{N}, isize) where N = output_size(PoolDims(_paddims(isize, (l.k..., 1, 1)), l.k; stride = l.stride, padding = l.pad))
|
||||
|
||||
"""
|
||||
MeanPool(k)
|
||||
|
||||
|
@ -331,3 +362,5 @@ end
|
|||
function Base.show(io::IO, m::MeanPool)
|
||||
print(io, "MeanPool(", m.k, ", pad = ", m.pad, ", stride = ", m.stride, ")")
|
||||
end
|
||||
|
||||
outdims(l::MeanPool{N}, isize) where N = output_size(PoolDims(_paddims(isize, (l.k..., 1, 1)), l.k; stride = l.stride, padding = l.pad))
|
|
@ -92,4 +92,19 @@ import Flux: activations
|
|||
@test size(SkipConnection(Dense(10,10), (a,b) -> cat(a, b, dims = 2))(input)) == (10,4)
|
||||
end
|
||||
end
|
||||
|
||||
@testset "output dimensions" begin
|
||||
m = Chain(Conv((3, 3), 3 => 16), Conv((3, 3), 16 => 32))
|
||||
@test Flux.outdims(m, (10, 10)) == (6, 6)
|
||||
|
||||
m = Dense(10, 5)
|
||||
@test Flux.outdims(m, (5, 2)) == (5,)
|
||||
@test Flux.outdims(m, (10,)) == (5,)
|
||||
|
||||
m = Flux.Diagonal(10)
|
||||
@test Flux.outdims(m, (10,)) == (10,)
|
||||
|
||||
m = Maxout(() -> Conv((3, 3), 3 => 16), 2)
|
||||
@test Flux.outdims(m, (10, 10)) == (8, 8)
|
||||
end
|
||||
end
|
||||
|
|
|
@ -107,3 +107,55 @@ end
|
|||
true
|
||||
end
|
||||
end
|
||||
|
||||
@testset "conv output dimensions" begin
|
||||
m = Conv((3, 3), 3 => 16)
|
||||
@test Flux.outdims(m, (10, 10)) == (8, 8)
|
||||
m = Conv((3, 3), 3 => 16; stride = 2)
|
||||
@test Flux.outdims(m, (5, 5)) == (2, 2)
|
||||
m = Conv((3, 3), 3 => 16; stride = 2, pad = 3)
|
||||
@test Flux.outdims(m, (5, 5)) == (5, 5)
|
||||
m = Conv((3, 3), 3 => 16; stride = 2, pad = 3, dilation = 2)
|
||||
@test Flux.outdims(m, (5, 5)) == (4, 4)
|
||||
|
||||
m = ConvTranspose((3, 3), 3 => 16)
|
||||
@test Flux.outdims(m, (8, 8)) == (10, 10)
|
||||
m = ConvTranspose((3, 3), 3 => 16; stride = 2)
|
||||
@test Flux.outdims(m, (2, 2)) == (5, 5)
|
||||
m = ConvTranspose((3, 3), 3 => 16; stride = 2, pad = 3)
|
||||
@test Flux.outdims(m, (5, 5)) == (5, 5)
|
||||
m = ConvTranspose((3, 3), 3 => 16; stride = 2, pad = 3, dilation = 2)
|
||||
@test Flux.outdims(m, (4, 4)) == (5, 5)
|
||||
|
||||
m = DepthwiseConv((3, 3), 3 => 6)
|
||||
@test Flux.outdims(m, (10, 10)) == (8, 8)
|
||||
m = DepthwiseConv((3, 3), 3 => 6; stride = 2)
|
||||
@test Flux.outdims(m, (5, 5)) == (2, 2)
|
||||
m = DepthwiseConv((3, 3), 3 => 6; stride = 2, pad = 3)
|
||||
@test Flux.outdims(m, (5, 5)) == (5, 5)
|
||||
m = DepthwiseConv((3, 3), 3 => 6; stride = 2, pad = 3, dilation = 2)
|
||||
@test Flux.outdims(m, (5, 5)) == (4, 4)
|
||||
|
||||
m = CrossCor((3, 3), 3 => 16)
|
||||
@test Flux.outdims(m, (10, 10)) == (8, 8)
|
||||
m = CrossCor((3, 3), 3 => 16; stride = 2)
|
||||
@test Flux.outdims(m, (5, 5)) == (2, 2)
|
||||
m = CrossCor((3, 3), 3 => 16; stride = 2, pad = 3)
|
||||
@test Flux.outdims(m, (5, 5)) == (5, 5)
|
||||
m = CrossCor((3, 3), 3 => 16; stride = 2, pad = 3, dilation = 2)
|
||||
@test Flux.outdims(m, (5, 5)) == (4, 4)
|
||||
|
||||
m = MaxPool((2, 2))
|
||||
@test Flux.outdims(m, (10, 10)) == (5, 5)
|
||||
m = MaxPool((2, 2); stride = 1)
|
||||
@test Flux.outdims(m, (5, 5)) == (4, 4)
|
||||
m = MaxPool((2, 2); stride = 2, pad = 3)
|
||||
@test Flux.outdims(m, (5, 5)) == (5, 5)
|
||||
|
||||
m = MeanPool((2, 2))
|
||||
@test Flux.outdims(m, (10, 10)) == (5, 5)
|
||||
m = MeanPool((2, 2); stride = 1)
|
||||
@test Flux.outdims(m, (5, 5)) == (4, 4)
|
||||
m = MeanPool((2, 2); stride = 2, pad = 3)
|
||||
@test Flux.outdims(m, (5, 5)) == (5, 5)
|
||||
end
|
Loading…
Reference in New Issue