added GlobalMaxPool, GlobalMeanPool, and Flatten layers

Garben Tanghe 2019-12-02 13:31:25 +01:00
parent 5a4f1932a6
commit 3e14bd878c
4 changed files with 86 additions and 6 deletions

@@ -14,10 +14,13 @@ These layers are used to build convolutional neural networks (CNNs).
```@docs
Conv
MaxPool
GlobalMaxPool
MeanPool
GlobalMeanPool
DepthwiseConv
ConvTranspose
CrossCor
Flatten
```
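For context, here is a minimal sketch of how the new layers can be combined with the existing ones in a small CNN. The input shape and layer sizes below are illustrative assumptions, not part of this commit:

```julia
using Flux

# Hypothetical model: 28×28 grayscale images, 10 output classes.
model = Chain(
  Conv((3, 3), 1 => 16, relu),   # 26×26×16
  MaxPool((2, 2)),               # 13×13×16
  Conv((3, 3), 16 => 32, relu),  # 11×11×32
  GlobalMeanPool(),              # 1×1×32
  Flatten(),                     # 32 features per sample
  Dense(32, 10),
  softmax)

x = rand(Float32, 28, 28, 1, 8)  # batch of 8 images in WHCN order
size(model(x))                   # (10, 8)
```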
## Recurrent Layers

@@ -10,7 +10,8 @@ using Zygote: Params, @adjoint, gradient, pullback, @nograd
export gradient
export Chain, Dense, Maxout, RNN, LSTM, GRU, Conv, CrossCor, ConvTranspose, MaxPool, MeanPool,
export Chain, Dense, Maxout, RNN, LSTM, GRU, Conv, CrossCor, ConvTranspose,
GlobalMaxPool, GlobalMeanPool, MaxPool, MeanPool, Flatten,
DepthwiseConv, Dropout, AlphaDropout, LayerNorm, BatchNorm, InstanceNorm, GroupNorm,
SkipConnection, params, fmap, cpu, gpu, f32, f64, testmode!, trainmode!

@@ -95,8 +95,9 @@ outdims(l::Conv, isize) =
Standard convolutional transpose layer. `size` should be a tuple like `(2, 2)`.
`in` and `out` specify the number of input and output channels respectively.
Data should be stored in WHCN order. In other words, a 100×100 RGB image would
be a `100×100×3` array, and a batch of 50 would be a `100×100×3×50` array.
Data should be stored in WHCN order (width, height, # channels, # batches).
In other words, a 100×100 RGB image would be a `100×100×3×1` array,
and a batch of 50 would be a `100×100×3×50` array.
Takes the keyword arguments `pad`, `stride` and `dilation`.
"""
@@ -171,8 +172,9 @@ Depthwise convolutional layer. `size` should be a tuple like `(2, 2)`.
`in` and `out` specify the number of input and output channels respectively.
Note that `out` must be an integer multiple of `in`.
Data should be stored in WHCN order. In other words, a 100×100 RGB image would
be a `100×100×3` array, and a batch of 50 would be a `100×100×3×50` array.
Data should be stored in WHCN order (width, height, # channels, # batches).
In other words, a 100×100 RGB image would be a `100×100×3×1` array,
and a batch of 50 would be a `100×100×3×50` array.
Takes the keyword arguments `pad`, `stride` and `dilation`.
"""
@@ -304,6 +306,56 @@ end
outdims(l::CrossCor, isize) =
output_size(DenseConvDims(_paddims(isize, size(l.weight)), size(l.weight); stride = l.stride, padding = l.pad, dilation = l.dilation))
"""
GlobalMaxPool()
Global max pooling layer.
Transforms (w,h,c,b)-shaped input into (1,1,c,b)-shaped output,
by performing max pooling on the complete (w,h)-shaped feature maps.
"""
struct GlobalMaxPool end
function (g::GlobalMaxPool)(x)
# Input size
x_size = size(x)
# Kernel size
k = x_size[1:end-2]
# Pooling dimensions
pdims = PoolDims(x, k)
return maxpool(x, pdims)
end
function Base.show(io::IO, g::GlobalMaxPool)
print(io, "GlobalMaxPool()")
end
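A usage sketch for the new layer, assuming the shapes described in the docstring (the concrete sizes are arbitrary):

```julia
using Flux

x = rand(Float32, 7, 7, 16, 4)    # WHCN input
y = GlobalMaxPool()(x)
size(y)                           # (1, 1, 16, 4)
y == maximum(x, dims = (1, 2))    # true: each feature map collapses to its maximum
```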
"""
GlobalMeanPool()
Global mean pooling layer.
Transforms (w,h,c,b)-shaped input into (1,1,c,b)-shaped output,
by performing mean pooling on the complete (w,h)-shaped feature maps.
"""
struct GlobalMeanPool end
function (g::GlobalMeanPool)(x)
# Input size
x_size = size(x)
# Kernel size
k = x_size[1:end-2]
# Pooling dimensions
pdims = PoolDims(x, k)
return meanpool(x, pdims)
end
function Base.show(io::IO, g::GlobalMeanPool)
print(io, "GlobalMeanPool()")
end
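An analogous sketch for mean pooling, with the same assumed shapes:

```julia
using Flux, Statistics

x = rand(Float32, 7, 7, 16, 4)
y = GlobalMeanPool()(x)
size(y)                           # (1, 1, 16, 4)
y ≈ mean(x, dims = (1, 2))        # holds up to floating-point rounding
```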
"""
    MaxPool(k)
@@ -363,4 +415,22 @@ function Base.show(io::IO, m::MeanPool)
print(io, "MeanPool(", m.k, ", pad = ", m.pad, ", stride = ", m.stride, ")")
end
outdims(l::MeanPool{N}, isize) where N = output_size(PoolDims(_paddims(isize, (l.k..., 1, 1)), l.k; stride = l.stride, padding = l.pad))
"""
Flatten()
Flattening layer.
Transforms (w,h,c,b)-shaped input into (w*h*c,b)-shaped output,
by linearizing all values for each element in the batch.
"""
struct Flatten end
function (f::Flatten)(x)
return reshape(x, :, size(x)[end])
end
function Base.show(io::IO, f::Flatten)
print(io, "Flatten()")
end
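And a sketch for `Flatten`, which is typically placed between the convolutional part of a network and a `Dense` layer (the sizes mirror the test below):

```julia
using Flux

x = rand(Float32, 10, 10, 3, 2)
f = Flatten()
size(f(x))                  # (300, 2): 10*10*3 values per batch element
f(x) == reshape(x, 300, 2)  # true, Flatten is just a reshape
```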

@@ -4,10 +4,16 @@ using Flux: gradient
@testset "Pooling" begin
  x = randn(Float32, 10, 10, 3, 2)
  gmp = GlobalMaxPool()
  @test size(gmp(x)) == (1, 1, 3, 2)
  gmp = GlobalMeanPool()
  @test size(gmp(x)) == (1, 1, 3, 2)
  mp = MaxPool((2, 2))
  @test mp(x) == maxpool(x, PoolDims(x, 2))
  mp = MeanPool((2, 2))
  @test mp(x) == meanpool(x, PoolDims(x, 2))
  f = Flatten()
  @test size(f(x)) == (300, 2)
end
@testset "CNN" begin