Merge #950
950: added GlobalMaxPool, GlobalMeanPool, and flatten layers r=CarloLucibello a=gartangh
Co-authored-by: Garben Tanghe <garben.tanghe@gmail.com>
commit d4cf1436df
docs/src/models/layers.md
@@ -14,10 +14,13 @@ These layers are used to build convolutional neural networks (CNNs).
 ```@docs
 Conv
 MaxPool
+GlobalMaxPool
 MeanPool
+GlobalMeanPool
 DepthwiseConv
 ConvTranspose
 CrossCor
+flatten
 ```

 ## Recurrent Layers
src/Flux.jl
@@ -10,7 +10,8 @@ using Zygote: Params, @adjoint, gradient, pullback, @nograd

 export gradient

-export Chain, Dense, Maxout, RNN, LSTM, GRU, Conv, CrossCor, ConvTranspose, MaxPool, MeanPool,
+export Chain, Dense, Maxout, RNN, LSTM, GRU, Conv, CrossCor, ConvTranspose,
+       GlobalMaxPool, GlobalMeanPool, MaxPool, MeanPool, flatten,
        DepthwiseConv, Dropout, AlphaDropout, LayerNorm, BatchNorm, InstanceNorm, GroupNorm,
        SkipConnection, params, fmap, cpu, gpu, f32, f64, testmode!, trainmode!
src/layers/conv.jl
@@ -95,8 +95,9 @@ outdims(l::Conv, isize) =
 Standard convolutional transpose layer. `size` should be a tuple like `(2, 2)`.
 `in` and `out` specify the number of input and output channels respectively.

-Data should be stored in WHCN order. In other words, a 100×100 RGB image would
-be a `100×100×3` array, and a batch of 50 would be a `100×100×3×50` array.
+Data should be stored in WHCN order (width, height, # channels, # batches).
+In other words, a 100×100 RGB image would be a `100×100×3×1` array,
+and a batch of 50 would be a `100×100×3×50` array.

 Takes the keyword arguments `pad`, `stride` and `dilation`.
 """
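As a quick illustration of the WHCN convention above — a sketch against this branch, with arbitrary layer hyperparameters:

```julia
using Flux

# A single 100×100 RGB image in WHCN order: width, height, channels, batch.
img = rand(Float32, 100, 100, 3, 1)
# A batch of 50 such images shares the first three dimensions.
batch = rand(Float32, 100, 100, 3, 50)

# Any conv-style layer consumes this layout; 3 => 8 channels is arbitrary here.
ct = ConvTranspose((2, 2), 3 => 8)
size(ct(batch))  # (101, 101, 8, 50): a 2×2 transpose conv grows each spatial dim by 1
```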
@@ -171,8 +172,9 @@ Depthwise convolutional layer. `size` should be a tuple like `(2, 2)`.
 `in` and `out` specify the number of input and output channels respectively.
 Note that `out` must be an integer multiple of `in`.

-Data should be stored in WHCN order. In other words, a 100×100 RGB image would
-be a `100×100×3` array, and a batch of 50 would be a `100×100×3×50` array.
+Data should be stored in WHCN order (width, height, # channels, # batches).
+In other words, a 100×100 RGB image would be a `100×100×3×1` array,
+and a batch of 50 would be a `100×100×3×50` array.

 Takes the keyword arguments `pad`, `stride` and `dilation`.
 """
@@ -304,6 +306,56 @@ end
 outdims(l::CrossCor, isize) =
   output_size(DenseConvDims(_paddims(isize, size(l.weight)), size(l.weight); stride = l.stride, padding = l.pad, dilation = l.dilation))

+"""
+    GlobalMaxPool()
+
+Global max pooling layer.
+
+Transforms (w,h,c,b)-shaped input into (1,1,c,b)-shaped output,
+by performing max pooling on the complete (w,h)-shaped feature maps.
+"""
+struct GlobalMaxPool end
+
+function (g::GlobalMaxPool)(x)
+  # Input size
+  x_size = size(x)
+  # Kernel size
+  k = x_size[1:end-2]
+  # Pooling dimensions
+  pdims = PoolDims(x, k)
+
+  return maxpool(x, pdims)
+end
+
+function Base.show(io::IO, g::GlobalMaxPool)
+  print(io, "GlobalMaxPool()")
+end
+
+"""
+    GlobalMeanPool()
+
+Global mean pooling layer.
+
+Transforms (w,h,c,b)-shaped input into (1,1,c,b)-shaped output,
+by performing mean pooling on the complete (w,h)-shaped feature maps.
+"""
+struct GlobalMeanPool end
+
+function (g::GlobalMeanPool)(x)
+  # Input size
+  x_size = size(x)
+  # Kernel size
+  k = x_size[1:end-2]
+  # Pooling dimensions
+  pdims = PoolDims(x, k)
+
+  return meanpool(x, pdims)
+end
+
+function Base.show(io::IO, g::GlobalMeanPool)
+  print(io, "GlobalMeanPool()")
+end
+
 """
     MaxPool(k)
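With this branch checked out, the new layers compose with existing ones as usual; a minimal sketch (the Chain shown is illustrative, not from the PR):

```julia
using Flux

x = rand(Float32, 10, 10, 3, 2)  # (w, h, c, b)

size(GlobalMaxPool()(x))   # (1, 1, 3, 2)
size(GlobalMeanPool()(x))  # (1, 1, 3, 2)

# Typical use: collapse the feature maps ahead of a classifier head.
m = Chain(Conv((3, 3), 3 => 16, relu),  # (8, 8, 16, 2)
          GlobalMeanPool(),             # (1, 1, 16, 2)
          flatten,                      # (16, 2)
          Dense(16, 10))                # (10, 2)
size(m(x))  # (10, 2)
```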
@@ -363,4 +415,4 @@ function Base.show(io::IO, m::MeanPool)
   print(io, "MeanPool(", m.k, ", pad = ", m.pad, ", stride = ", m.stride, ")")
 end

 outdims(l::MeanPool{N}, isize) where N = output_size(PoolDims(_paddims(isize, (l.k..., 1, 1)), l.k; stride = l.stride, padding = l.pad))
src/layers/stateless.jl
@@ -2,7 +2,7 @@
 """
     mae(ŷ, y)

 Return the mean of absolute error `sum(abs.(ŷ .- y)) / length(y)`
 """
 mae(ŷ, y) = sum(abs.(ŷ .- y)) * 1 // length(y)
@@ -10,7 +10,7 @@ mae(ŷ, y) = sum(abs.(ŷ .- y)) * 1 // length(y)
 """
     mse(ŷ, y)

 Return the mean squared error `sum((ŷ .- y).^2) / length(y)`.
 """
 mse(ŷ, y) = sum((ŷ .- y).^2) * 1 // length(y)
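A hand-computed sanity check on the two formulas above (not part of the test suite):

```julia
using Flux: mae, mse

ŷ = [0.5, 0.5]
y = [1.0, 0.0]

mae(ŷ, y)  # (0.5 + 0.5) / 2 = 0.5
mse(ŷ, y)  # (0.5² + 0.5²) / 2 = 0.25
```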
@@ -19,7 +19,7 @@ mse(ŷ, y) = sum((ŷ .- y).^2) * 1 // length(y)
     msle(ŷ, y; ϵ=eps(eltype(ŷ)))

 Returns the mean of the squared logarithmic errors `sum((log.(ŷ .+ ϵ) .- log.(y .+ ϵ)).^2) / length(y)`.
 The `ϵ` term provides numerical stability.

 This error penalizes an under-predicted estimate greater than an over-predicted estimate.
 """
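That asymmetry is easy to see numerically; a small hand-computed illustration:

```julia
using Flux: msle

y = [2.0]
msle([1.0], y)  # (log(1) - log(2))² ≈ 0.48  -- under-prediction
msle([3.0], y)  # (log(3) - log(2))² ≈ 0.16  -- over-prediction of the same size costs less
```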
@@ -60,7 +60,7 @@ end
 """
     crossentropy(ŷ, y; weight=1)

 Return the crossentropy computed as `-sum(y .* log.(ŷ) .* weight) / size(y, 2)`.

 See also [`logitcrossentropy`](@ref), [`binarycrossentropy`](@ref).
 """
@@ -69,7 +69,7 @@ crossentropy(ŷ::AbstractVecOrMat, y::AbstractVecOrMat; weight=nothing) = _cros
 """
     logitcrossentropy(ŷ, y; weight=1)

 Return the crossentropy computed after a [softmax](@ref) operation:

     -sum(y .* logsoftmax(ŷ) .* weight) / size(y, 2)
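The softmax relationship between the two losses can be checked directly; a sketch (function names come from Flux itself, the data is made up):

```julia
using Flux: crossentropy, logitcrossentropy, softmax, onehotbatch

logits = randn(Float32, 3, 5)        # 3 classes, 5 samples
y = onehotbatch(rand(1:3, 5), 1:3)   # one-hot targets

# logitcrossentropy folds the softmax into the loss for numerical stability.
logitcrossentropy(logits, y) ≈ crossentropy(softmax(logits), y)  # true
```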
@@ -97,7 +97,7 @@ CuArrays.@cufunc binarycrossentropy(ŷ, y; ϵ=eps(ŷ)) = -y*log(ŷ + ϵ) - (1
 `logitbinarycrossentropy(ŷ, y)` is mathematically equivalent to `binarycrossentropy(σ(ŷ), y)`
 but it is more numerically stable.

 See also [`binarycrossentropy`](@ref), [`sigmoid`](@ref), [`logsigmoid`](@ref).
 """
 logitbinarycrossentropy(ŷ, y) = (1 - y)*ŷ - logσ(ŷ)
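Likewise for the binary case; a one-line check with arbitrary values:

```julia
using Flux: binarycrossentropy, logitbinarycrossentropy, σ

ŷ, y = 0.3f0, 1f0  # a raw logit and a binary label
logitbinarycrossentropy(ŷ, y) ≈ binarycrossentropy(σ(ŷ), y)  # true
```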
@@ -162,7 +162,7 @@ poisson(ŷ, y) = sum(ŷ .- y .* log.(ŷ)) * 1 // size(y,2)
 """
     hinge(ŷ, y)

 Measures the loss given the prediction `ŷ` and true labels `y` (containing 1 or -1).
 Returns `sum((max.(0, 1 .- ŷ .* y))) / size(y, 2)`

 [Hinge Loss](https://en.wikipedia.org/wiki/Hinge_loss)
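A hand-computed example of that formula, using row matrices so that `size(y, 2)` counts samples:

```julia
using Flux: hinge

ŷ = [0.3 -0.8]  # predictions for two samples
y = [1 -1]      # true labels in {1, -1}
hinge(ŷ, y)     # (max(0, 1 - 0.3) + max(0, 1 - 0.8)) / 2 = 0.45
```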
@@ -193,10 +193,20 @@ dice_coeff_loss(ŷ, y; smooth=eltype(ŷ)(1.0)) = 1 - (2*sum(y .* ŷ) + smooth
 """
     tversky_loss(ŷ, y; β=0.7)

 Used with imbalanced data to give more weightage to False negatives.
 Larger β weigh recall higher than precision (by placing more emphasis on false negatives)
 Returns `1 - sum(|y .* ŷ| + 1) / (sum(y .* ŷ + β*(1 .- y) .* ŷ + (1 - β)*y .* (1 .- ŷ)) + 1)`

 [Tversky loss function for image segmentation using 3D fully convolutional deep networks](https://arxiv.org/pdf/1706.05721.pdf)
 """
 tversky_loss(ŷ, y; β=eltype(ŷ)(0.7)) = 1 - (sum(y .* ŷ) + 1) / (sum(y .* ŷ + β*(1 .- y) .* ŷ + (1 - β)*y .* (1 .- ŷ)) + 1)

+"""
+    flatten(x::AbstractArray)
+
+Transforms (w,h,c,b)-shaped input into (w x h x c,b)-shaped output,
+by linearizing all values for each element in the batch.
+"""
+function flatten(x::AbstractArray)
+  return reshape(x, :, size(x)[end])
+end
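`flatten` is a batch-preserving reshape, which makes it the natural bridge from convolutional feature maps to a `Dense` layer; a minimal sketch (layer sizes are illustrative):

```julia
using Flux

x = rand(Float32, 10, 10, 3, 2)  # (w, h, c, b)
size(flatten(x))                  # (300, 2): 10 × 10 × 3 values per sample

head = Chain(flatten, Dense(300, 10))  # batch axis stays last
size(head(x))                          # (10, 2)
```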
test/layers/conv.jl
@@ -4,6 +4,10 @@ using Flux: gradient

 @testset "Pooling" begin
   x = randn(Float32, 10, 10, 3, 2)
+  gmp = GlobalMaxPool()
+  @test size(gmp(x)) == (1, 1, 3, 2)
+  gmp = GlobalMeanPool()
+  @test size(gmp(x)) == (1, 1, 3, 2)
   mp = MaxPool((2, 2))
   @test mp(x) == maxpool(x, PoolDims(x, 2))
   mp = MeanPool((2, 2))
test/layers/stateless.jl
@@ -1,6 +1,6 @@
 using Test
 using Flux: onehotbatch, mse, crossentropy, logitcrossentropy,
-            σ, binarycrossentropy, logitbinarycrossentropy
+            σ, binarycrossentropy, logitbinarycrossentropy, flatten

 const ϵ = 1e-7
@@ -116,3 +116,10 @@ const ϵ = 1e-7
     end
   end
 end
+
+@testset "helpers" begin
+  @testset "flatten" begin
+    x = randn(Float32, 10, 10, 3, 2)
+    @test size(flatten(x)) == (300, 2)
+  end
+end