950: added GlobalMaxPool, GlobalMeanPool, and flatten layers r=CarloLucibello a=gartangh



Co-authored-by: Garben Tanghe <garben.tanghe@gmail.com>
commit d4cf1436df by bors[bot], 2020-03-08 14:27:10 +00:00, committed by GitHub
6 changed files with 92 additions and 15 deletions

File: docs/src/models/layers.md

@@ -14,10 +14,13 @@ These layers are used to build convolutional neural networks (CNNs).
 ```@docs
 Conv
 MaxPool
+GlobalMaxPool
 MeanPool
+GlobalMeanPool
 DepthwiseConv
 ConvTranspose
 CrossCor
+flatten
 ```
 
 ## Recurrent Layers

File: src/Flux.jl

@@ -10,7 +10,8 @@ using Zygote: Params, @adjoint, gradient, pullback, @nograd
 export gradient
 
-export Chain, Dense, Maxout, RNN, LSTM, GRU, Conv, CrossCor, ConvTranspose, MaxPool, MeanPool,
+export Chain, Dense, Maxout, RNN, LSTM, GRU, Conv, CrossCor, ConvTranspose,
+       GlobalMaxPool, GlobalMeanPool, MaxPool, MeanPool, flatten,
        DepthwiseConv, Dropout, AlphaDropout, LayerNorm, BatchNorm, InstanceNorm, GroupNorm,
        SkipConnection, params, fmap, cpu, gpu, f32, f64, testmode!, trainmode!

File: src/layers/conv.jl

@@ -95,8 +95,9 @@ outdims(l::Conv, isize) =
 Standard convolutional transpose layer. `size` should be a tuple like `(2, 2)`.
 `in` and `out` specify the number of input and output channels respectively.
-Data should be stored in WHCN order. In other words, a 100×100 RGB image would
-be a `100×100×3` array, and a batch of 50 would be a `100×100×3×50` array.
+Data should be stored in WHCN order (width, height, # channels, # batches).
+In other words, a 100×100 RGB image would be a `100×100×3×1` array,
+and a batch of 50 would be a `100×100×3×50` array.
 Takes the keyword arguments `pad`, `stride` and `dilation`.
 """
@@ -171,8 +172,9 @@ Depthwise convolutional layer. `size` should be a tuple like `(2, 2)`.
 `in` and `out` specify the number of input and output channels respectively.
 Note that `out` must be an integer multiple of `in`.
-Data should be stored in WHCN order. In other words, a 100×100 RGB image would
-be a `100×100×3` array, and a batch of 50 would be a `100×100×3×50` array.
+Data should be stored in WHCN order (width, height, # channels, # batches).
+In other words, a 100×100 RGB image would be a `100×100×3×1` array,
+and a batch of 50 would be a `100×100×3×50` array.
 Takes the keyword arguments `pad`, `stride` and `dilation`.
 """
@@ -304,6 +306,56 @@ end
 outdims(l::CrossCor, isize) =
   output_size(DenseConvDims(_paddims(isize, size(l.weight)), size(l.weight); stride = l.stride, padding = l.pad, dilation = l.dilation))
 
+"""
+    GlobalMaxPool()
+
+Global max pooling layer.
+
+Transforms (w,h,c,b)-shaped input into (1,1,c,b)-shaped output,
+by performing max pooling on the complete (w,h)-shaped feature maps.
+"""
+struct GlobalMaxPool end
+
+function (g::GlobalMaxPool)(x)
+  # Input size
+  x_size = size(x)
+  # Kernel size
+  k = x_size[1:end-2]
+  # Pooling dimensions
+  pdims = PoolDims(x, k)
+
+  return maxpool(x, pdims)
+end
+
+function Base.show(io::IO, g::GlobalMaxPool)
+  print(io, "GlobalMaxPool()")
+end
+
+"""
+    GlobalMeanPool()
+
+Global mean pooling layer.
+
+Transforms (w,h,c,b)-shaped input into (1,1,c,b)-shaped output,
+by performing mean pooling on the complete (w,h)-shaped feature maps.
+"""
+struct GlobalMeanPool end
+
+function (g::GlobalMeanPool)(x)
+  # Input size
+  x_size = size(x)
+  # Kernel size
+  k = x_size[1:end-2]
+  # Pooling dimensions
+  pdims = PoolDims(x, k)
+
+  return meanpool(x, pdims)
+end
+
+function Base.show(io::IO, g::GlobalMeanPool)
+  print(io, "GlobalMeanPool()")
+end
+
 """
     MaxPool(k)
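Since both new layers pool over the complete feature map, their output should match a plain reduction over the spatial dimensions. A minimal sketch of that equivalence (not part of the commit):

```julia
using Flux
using Statistics: mean

x = rand(Float32, 10, 10, 3, 2)
gmp, gap = GlobalMaxPool(), GlobalMeanPool()
size(gmp(x)) == size(gap(x)) == (1, 1, 3, 2)  # (w,h,c,b) -> (1,1,c,b)
gmp(x) ≈ maximum(x, dims=(1, 2))              # max over each full feature map
gap(x) ≈ mean(x, dims=(1, 2))                 # mean over each full feature map
```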
@@ -363,4 +415,4 @@ function Base.show(io::IO, m::MeanPool)
   print(io, "MeanPool(", m.k, ", pad = ", m.pad, ", stride = ", m.stride, ")")
 end
 
 outdims(l::MeanPool{N}, isize) where N = output_size(PoolDims(_paddims(isize, (l.k..., 1, 1)), l.k; stride = l.stride, padding = l.pad))

File: src/layers/stateless.jl

@@ -2,7 +2,7 @@
 """
     mae(ŷ, y)
 Return the mean of absolute error `sum(abs.(ŷ .- y)) / length(y)`
 """
 mae(ŷ, y) = sum(abs.(ŷ .- y)) * 1 // length(y)
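A worked instance of the formula (illustration only; `mae` is imported rather than exported):

```julia
using Flux: mae

mae([2.0, 3.0], [1.0, 5.0])  # (|2 - 1| + |3 - 5|) / 2 == 1.5
```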
@@ -10,7 +10,7 @@ mae(ŷ, y) = sum(abs.(ŷ .- y)) * 1 // length(y)
 """
     mse(ŷ, y)
 Return the mean squared error `sum((ŷ .- y).^2) / length(y)`.
 """
 mse(ŷ, y) = sum((ŷ .- y).^2) * 1 // length(y)
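The same inputs under the squared error (illustration only):

```julia
using Flux: mse

mse([2.0, 3.0], [1.0, 5.0])  # ((2 - 1)^2 + (3 - 5)^2) / 2 == 2.5
```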
@@ -19,7 +19,7 @@ mse(ŷ, y) = sum((ŷ .- y).^2) * 1 // length(y)
     msle(ŷ, y; ϵ=eps(eltype(ŷ)))
 Returns the mean of the squared logarithmic errors `sum((log.(ŷ .+ ϵ) .- log.(y .+ ϵ)).^2) / length(y)`.
 The `ϵ` term provides numerical stability.
 This error penalizes an under-predicted estimate greater than an over-predicted estimate.
 """
@@ -60,7 +60,7 @@ end
 """
     crossentropy(ŷ, y; weight=1)
 Return the crossentropy computed as `-sum(y .* log.(ŷ) .* weight) / size(y, 2)`.
 See also [`logitcrossentropy`](@ref), [`binarycrossentropy`](@ref).
 """
@@ -69,7 +69,7 @@ crossentropy(ŷ::AbstractVecOrMat, y::AbstractVecOrMat; weight=nothing) = _cros
 """
     logitcrossentropy(ŷ, y; weight=1)
 Return the crossentropy computed after a [softmax](@ref) operation:
     -sum(y .* logsoftmax(ŷ) .* weight) / size(y, 2)
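The relationship to `crossentropy` can be verified on raw scores (illustration only):

```julia
using Flux: crossentropy, logitcrossentropy, onehotbatch, softmax

z = randn(Float32, 3, 5)            # raw, unnormalized scores (logits)
y = onehotbatch(rand(1:3, 5), 1:3)
logitcrossentropy(z, y) ≈ crossentropy(softmax(z), y)  # same value, better stability
```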
@@ -97,7 +97,7 @@ CuArrays.@cufunc binarycrossentropy(ŷ, y; ϵ=eps(ŷ)) = -y*log(ŷ + ϵ) - (1
 `logitbinarycrossentropy(ŷ, y)` is mathematically equivalent to `binarycrossentropy(σ(ŷ), y)`
 but it is more numerically stable.
 See also [`binarycrossentropy`](@ref), [`sigmoid`](@ref), [`logsigmoid`](@ref).
 """
 logitbinarycrossentropy(ŷ, y) = (1 - y)*ŷ - logσ(ŷ)
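Likewise for the binary case; both functions act elementwise, hence the scalar arguments (illustration only):

```julia
using Flux: binarycrossentropy, logitbinarycrossentropy, σ

ŷ, y = randn(), rand(0:1)
logitbinarycrossentropy(ŷ, y) ≈ binarycrossentropy(σ(ŷ), y)  # true
```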
@@ -162,7 +162,7 @@ poisson(ŷ, y) = sum(ŷ .- y .* log.(ŷ)) * 1 // size(y,2)
 """
     hinge(ŷ, y)
 Measures the loss given the prediction `ŷ` and true labels `y` (containing 1 or -1).
 Returns `sum((max.(0, 1 .- ŷ .* y))) / size(y, 2)`
 [Hinge Loss](https://en.wikipedia.org/wiki/Hinge_loss)
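A small worked case (illustration only): with labels in `{-1, 1}`, only predictions that violate the unit margin contribute.

```julia
using Flux: hinge

ŷ = [0.7 -1.5 0.2]  # one prediction per column
y = [1 -1 -1]       # true labels in {-1, 1}
hinge(ŷ, y)         # (0.3 + 0.0 + 1.2) / 3 == 0.5
```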
@@ -193,10 +193,20 @@ dice_coeff_loss(ŷ, y; smooth=eltype(ŷ)(1.0)) = 1 - (2*sum(y .* ŷ) + smooth
 """
     tversky_loss(ŷ, y; β=0.7)
 Used with imbalanced data to give more weightage to False negatives.
 Larger β weigh recall higher than precision (by placing more emphasis on false negatives)
 Returns `1 - sum(|y .* ŷ| + 1) / (sum(y .* ŷ + β*(1 .- y) .* ŷ + (1 - β)*y .* (1 .- ŷ)) + 1)`
 [Tversky loss function for image segmentation using 3D fully convolutional deep networks](https://arxiv.org/pdf/1706.05721.pdf)
 """
 tversky_loss(ŷ, y; β=eltype(ŷ)(0.7)) = 1 - (sum(y .* ŷ) + 1) / (sum(y .* ŷ + β*(1 .- y) .* ŷ + (1 - β)*y .* (1 .- ŷ)) + 1)
+
+"""
+    flatten(x::AbstractArray)
+
+Transforms (w,h,c,b)-shaped input into (w x h x c,b)-shaped output,
+by linearizing all values for each element in the batch.
+"""
+function flatten(x::AbstractArray)
+  return reshape(x, :, size(x)[end])
+end
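`flatten` is a batch-preserving `reshape`, which is what lets it sit between a conv stack and a `Dense` layer. A sketch (not part of the diff; the layer sizes are invented for the example):

```julia
using Flux

x = rand(Float32, 10, 10, 3, 2)
flatten(x) == reshape(x, 300, 2)  # true: 10 × 10 × 3 = 300 features per sample

# Typical use: bridge convolutional feature maps into a fully connected head.
m = Chain(Conv((3, 3), 3 => 8), flatten, Dense(512, 10))  # 8×8×8 = 512 after the conv
size(m(x))  # (10, 2)
```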

File: test/layers/conv.jl

@@ -4,6 +4,10 @@ using Flux: gradient
 
 @testset "Pooling" begin
   x = randn(Float32, 10, 10, 3, 2)
+  gmp = GlobalMaxPool()
+  @test size(gmp(x)) == (1, 1, 3, 2)
+  gmp = GlobalMeanPool()
+  @test size(gmp(x)) == (1, 1, 3, 2)
   mp = MaxPool((2, 2))
   @test mp(x) == maxpool(x, PoolDims(x, 2))
   mp = MeanPool((2, 2))

File: test/layers/stateless.jl

@@ -1,6 +1,6 @@
 using Test
 using Flux: onehotbatch, mse, crossentropy, logitcrossentropy,
-            σ, binarycrossentropy, logitbinarycrossentropy
+            σ, binarycrossentropy, logitbinarycrossentropy, flatten
 
 const ϵ = 1e-7

@@ -116,3 +116,10 @@ const ϵ = 1e-7
     end
   end
 end
+
+@testset "helpers" begin
+  @testset "flatten" begin
+    x = randn(Float32, 10, 10, 3, 2)
+    @test size(flatten(x)) == (300, 2)
+  end
+end