diff --git a/docs/src/models/basics.md b/docs/src/models/basics.md index d83fc462..6e8d0b76 100644 --- a/docs/src/models/basics.md +++ b/docs/src/models/basics.md @@ -219,3 +219,24 @@ Flux.@functor Affine ``` This enables a useful extra set of functionality for our `Affine` layer, such as [collecting its parameters](../training/optimisers.md) or [moving it to the GPU](../gpu.md). + +## Utility functions + +Flux provides some utility functions to help you generate models in an automated fashion. + +`outdims` enables you to calculate the spatial output dimensions of layers like `Conv` when applied to input images of a given size. +Currently limited to the following layers: +- `Chain` +- `Dense` +- `Conv` +- `Diagonal` +- `Maxout` +- `ConvTranspose` +- `DepthwiseConv` +- `CrossCor` +- `MaxPool` +- `MeanPool` + +```@docs +outdims +``` diff --git a/src/layers/basic.jl b/src/layers/basic.jl index 2a465208..6f056429 100644 --- a/src/layers/basic.jl +++ b/src/layers/basic.jl @@ -39,6 +39,17 @@ function Base.show(io::IO, c::Chain) print(io, ")") end +""" + outdims(c::Chain, isize) + +Calculate the output dimensions given the input dimensions, `isize`. + +```julia +m = Chain(Conv((3, 3), 3 => 16), Conv((3, 3), 16 => 32)) +outdims(m, (10, 10)) == (6, 6) +``` +""" +outdims(c::Chain, isize) = foldl(∘, map(l -> (x -> outdims(l, x)), c.layers))(isize) # This is a temporary and naive implementation # it might be replaced in the future for better performance @@ -116,6 +127,19 @@ end (a::Dense{<:Any,W})(x::AbstractArray{<:AbstractFloat}) where {T <: Union{Float32,Float64}, W <: AbstractArray{T}} = a(T.(x)) +""" + outdims(l::Dense, isize) + +Calculate the output dimensions given the input dimensions, `isize`. + +```julia +m = Dense(10, 5) +outdims(m, (5, 2)) == (5,) +outdims(m, (10,)) == (5,) +``` +""" +outdims(l::Dense, isize) = (size(l.W)[1],) + """ Diagonal(in::Integer) @@ -145,6 +169,7 @@ function Base.show(io::IO, l::Diagonal) print(io, "Diagonal(", length(l.α), ")") end +outdims(l::Diagonal, isize) = (length(l.α),) """ Maxout(over) @@ -193,6 +218,8 @@ function (mo::Maxout)(input::AbstractArray) mapreduce(f -> f(input), (acc, out) -> max.(acc, out), mo.over) end +outdims(l::Maxout, isize) = outdims(first(l.over), isize) + """ SkipConnection(layers, connection) diff --git a/src/layers/conv.jl b/src/layers/conv.jl index 829051ae..ef167f71 100644 --- a/src/layers/conv.jl +++ b/src/layers/conv.jl @@ -1,4 +1,9 @@ -using NNlib: conv, ∇conv_data, depthwiseconv +using NNlib: conv, ∇conv_data, depthwiseconv, output_size + +# pad dims of x with dims of y until ndims(x) == ndims(y) +_paddims(x::Tuple, y::Tuple) = (x..., y[(end - (length(y) - length(x) - 1)):end]...) + +_convtransoutdims(isize, ksize, ssize, dsize, pad) = (isize .- 1).*ssize .+ 1 .+ (ksize .- 1).*dsize .- (pad[1:2:end] .+ pad[2:2:end]) expand(N, i::Tuple) = i expand(N, i::Integer) = ntuple(_ -> i, N) @@ -68,6 +73,21 @@ end (a::Conv{<:Any,<:Any,W})(x::AbstractArray{<:Real}) where {T <: Union{Float32,Float64}, W <: AbstractArray{T}} = a(T.(x)) +""" + outdims(l::Conv, isize::Tuple) + +Calculate the output dimensions given the input dimensions, `isize`. +Batch size and channel size are ignored as per `NNlib.jl`. + +```julia +m = Conv((3, 3), 3 => 16) +outdims(m, (10, 10)) == (8, 8) +outdims(m, (10, 10, 1, 3)) == (8, 8) +``` +""" +outdims(l::Conv, isize) = + output_size(DenseConvDims(_paddims(isize, size(l.weight)), size(l.weight); stride = l.stride, padding = l.pad, dilation = l.dilation)) + """ ConvTranspose(size, in=>out) ConvTranspose(size, in=>out, relu) @@ -140,6 +160,9 @@ end (a::ConvTranspose{<:Any,<:Any,W})(x::AbstractArray{<:Real}) where {T <: Union{Float32,Float64}, W <: AbstractArray{T}} = a(T.(x)) + +outdims(l::ConvTranspose{N}, isize) where N = _convtransoutdims(isize[1:2], size(l.weight)[1:N], l.stride, l.dilation, l.pad) + """ DepthwiseConv(size, in=>out) DepthwiseConv(size, in=>out, relu) @@ -204,6 +227,9 @@ end (a::DepthwiseConv{<:Any,<:Any,W})(x::AbstractArray{<:Real}) where {T <: Union{Float32,Float64}, W <: AbstractArray{T}} = a(T.(x)) +outdims(l::DepthwiseConv, isize) = + output_size(DepthwiseConvDims(_paddims(isize, (1, 1, size(l.weight)[end], 1)), size(l.weight); stride = l.stride, padding = l.pad, dilation = l.dilation)) + """ CrossCor(size, in=>out) CrossCor(size, in=>out, relu) @@ -275,6 +301,9 @@ end (a::CrossCor{<:Any,<:Any,W})(x::AbstractArray{<:Real}) where {T <: Union{Float32,Float64}, W <: AbstractArray{T}} = a(T.(x)) +outdims(l::CrossCor, isize) = + output_size(DenseConvDims(_paddims(isize, size(l.weight)), size(l.weight); stride = l.stride, padding = l.pad, dilation = l.dilation)) + """ MaxPool(k) @@ -304,6 +333,8 @@ function Base.show(io::IO, m::MaxPool) print(io, "MaxPool(", m.k, ", pad = ", m.pad, ", stride = ", m.stride, ")") end +outdims(l::MaxPool{N}, isize) where N = output_size(PoolDims(_paddims(isize, (l.k..., 1, 1)), l.k; stride = l.stride, padding = l.pad)) + """ MeanPool(k) @@ -331,3 +362,5 @@ end function Base.show(io::IO, m::MeanPool) print(io, "MeanPool(", m.k, ", pad = ", m.pad, ", stride = ", m.stride, ")") end + +outdims(l::MeanPool{N}, isize) where N = output_size(PoolDims(_paddims(isize, (l.k..., 1, 1)), l.k; stride = l.stride, padding = l.pad)) \ No newline at end of file diff --git a/test/layers/basic.jl b/test/layers/basic.jl index 0ff1776d..421c7721 100644 --- a/test/layers/basic.jl +++ b/test/layers/basic.jl @@ -92,4 +92,19 @@ import Flux: activations @test size(SkipConnection(Dense(10,10), (a,b) -> cat(a, b, dims = 2))(input)) == (10,4) end end + + @testset "output dimensions" begin + m = Chain(Conv((3, 3), 3 => 16), Conv((3, 3), 16 => 32)) + @test Flux.outdims(m, (10, 10)) == (6, 6) + + m = Dense(10, 5) + @test Flux.outdims(m, (5, 2)) == (5,) + @test Flux.outdims(m, (10,)) == (5,) + + m = Flux.Diagonal(10) + @test Flux.outdims(m, (10,)) == (10,) + + m = Maxout(() -> Conv((3, 3), 3 => 16), 2) + @test Flux.outdims(m, (10, 10)) == (8, 8) + end end diff --git a/test/layers/conv.jl b/test/layers/conv.jl index 1e7ede01..03a0d1a4 100644 --- a/test/layers/conv.jl +++ b/test/layers/conv.jl @@ -107,3 +107,55 @@ end true end end + +@testset "conv output dimensions" begin + m = Conv((3, 3), 3 => 16) + @test Flux.outdims(m, (10, 10)) == (8, 8) + m = Conv((3, 3), 3 => 16; stride = 2) + @test Flux.outdims(m, (5, 5)) == (2, 2) + m = Conv((3, 3), 3 => 16; stride = 2, pad = 3) + @test Flux.outdims(m, (5, 5)) == (5, 5) + m = Conv((3, 3), 3 => 16; stride = 2, pad = 3, dilation = 2) + @test Flux.outdims(m, (5, 5)) == (4, 4) + + m = ConvTranspose((3, 3), 3 => 16) + @test Flux.outdims(m, (8, 8)) == (10, 10) + m = ConvTranspose((3, 3), 3 => 16; stride = 2) + @test Flux.outdims(m, (2, 2)) == (5, 5) + m = ConvTranspose((3, 3), 3 => 16; stride = 2, pad = 3) + @test Flux.outdims(m, (5, 5)) == (5, 5) + m = ConvTranspose((3, 3), 3 => 16; stride = 2, pad = 3, dilation = 2) + @test Flux.outdims(m, (4, 4)) == (5, 5) + + m = DepthwiseConv((3, 3), 3 => 6) + @test Flux.outdims(m, (10, 10)) == (8, 8) + m = DepthwiseConv((3, 3), 3 => 6; stride = 2) + @test Flux.outdims(m, (5, 5)) == (2, 2) + m = DepthwiseConv((3, 3), 3 => 6; stride = 2, pad = 3) + @test Flux.outdims(m, (5, 5)) == (5, 5) + m = DepthwiseConv((3, 3), 3 => 6; stride = 2, pad = 3, dilation = 2) + @test Flux.outdims(m, (5, 5)) == (4, 4) + + m = CrossCor((3, 3), 3 => 16) + @test Flux.outdims(m, (10, 10)) == (8, 8) + m = CrossCor((3, 3), 3 => 16; stride = 2) + @test Flux.outdims(m, (5, 5)) == (2, 2) + m = CrossCor((3, 3), 3 => 16; stride = 2, pad = 3) + @test Flux.outdims(m, (5, 5)) == (5, 5) + m = CrossCor((3, 3), 3 => 16; stride = 2, pad = 3, dilation = 2) + @test Flux.outdims(m, (5, 5)) == (4, 4) + + m = MaxPool((2, 2)) + @test Flux.outdims(m, (10, 10)) == (5, 5) + m = MaxPool((2, 2); stride = 1) + @test Flux.outdims(m, (5, 5)) == (4, 4) + m = MaxPool((2, 2); stride = 2, pad = 3) + @test Flux.outdims(m, (5, 5)) == (5, 5) + + m = MeanPool((2, 2)) + @test Flux.outdims(m, (10, 10)) == (5, 5) + m = MeanPool((2, 2); stride = 1) + @test Flux.outdims(m, (5, 5)) == (4, 4) + m = MeanPool((2, 2); stride = 2, pad = 3) + @test Flux.outdims(m, (5, 5)) == (5, 5) +end \ No newline at end of file