Merge pull request #647 from oxinabox/ox/maxout

Add MaxOut layer
Mike J Innes 2019-03-22 12:18:53 +00:00 committed by GitHub
commit b637311642
4 changed files with 106 additions and 13 deletions
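
For context, a minimal usage sketch of the layer this PR adds (the `Dense` sizes here are illustrative; the constructor follows the docstring added below):

```julia
using Flux

# Maxout over 4 dense "alternatives": each sees the same input, and the
# layer returns the elementwise maximum of their outputs.
insize, outsize = 784, 128
mo = Maxout(() -> Dense(insize, outsize), 4)

x = rand(insize)   # a single input vector
size(mo(x))        # (128,)
```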

View File

@@ -5,14 +5,16 @@ These core layers form the foundation of almost all neural networks.
```@docs
Chain
Dense
```
## Convolution and Pooling Layers
These layers are used to build convolutional neural networks (CNNs).
```@docs
Conv
MaxPool
MeanPool
```
## Additional Convolution Layers
```@docs
DepthwiseConv
ConvTranspose
```
@@ -28,6 +30,25 @@ GRU
Flux.Recur
```
## Other General Purpose Layers
These layers are marginally more obscure than the basic layers above.
In contrast to the layers described in the other sections, however, they are not readily grouped around a particular purpose (e.g. CNNs or RNNs).
```@docs
Maxout
```
# Normalisation & Regularisation
These layers don't affect the structure of the network but may improve training times or reduce overfitting.
```@docs
Flux.testmode!
BatchNorm
Dropout
LayerNorm
```
## Activation Functions
Non-linearities that go between layers of your model. Most of these functions are defined in [NNlib](https://github.com/FluxML/NNlib.jl) but are available by default in Flux.

View File

@@ -6,8 +6,10 @@ using Base: tail
using MacroTools, Juno, Requires, Reexport, Statistics, Random
using MacroTools: @forward
export Chain, Dense, RNN, LSTM, GRU, Conv, ConvTranspose, MaxPool, MeanPool,
DepthwiseConv, Dropout, AlphaDropout, LayerNorm, BatchNorm, InstanceNorm,
export Chain, Dense, Maxout,
RNN, LSTM, GRU,
Conv, ConvTranspose, MaxPool, MeanPool, DepthwiseConv,
Dropout, AlphaDropout, LayerNorm, BatchNorm, InstanceNorm,
params, mapleaves, cpu, gpu, f32, f64
@reexport using NNlib

View File

@@ -88,6 +88,14 @@ function Base.show(io::IO, l::Dense)
print(io, ")")
end
# Try to avoid hitting generic matmul in some simple cases
# Base's matmul is so slow that it's worth the extra conversion to hit BLAS
(a::Dense{<:Any,W})(x::AbstractArray{T}) where {T <: Union{Float32,Float64}, W <: AbstractArray{T}} =
invoke(a, Tuple{AbstractArray}, x)
(a::Dense{<:Any,W})(x::AbstractArray{<:Real}) where {T <: Union{Float32,Float64}, W <: AbstractArray{T}} =
a(T.(x))
"""
Diagonal(in::Integer)
@@ -117,10 +125,48 @@ function Base.show(io::IO, l::Diagonal)
print(io, "Diagonal(", length(l.α), ")")
end
# Try to avoid hitting generic matmul in some simple cases
# Base's matmul is so slow that it's worth the extra conversion to hit BLAS
(a::Dense{<:Any,W})(x::AbstractArray{T}) where {T <: Union{Float32,Float64}, W <: AbstractArray{T}} =
invoke(a, Tuple{AbstractArray}, x)
(a::Dense{<:Any,W})(x::AbstractArray{<:Real}) where {T <: Union{Float32,Float64}, W <: AbstractArray{T}} =
a(T.(x))
"""
Maxout(over)
`Maxout` is a neural network layer which holds a number of internal layers.
Each internal layer receives the same input, and the `Maxout` layer returns the
elementwise maximum of the internal layers' outputs.
Maxout over linear dense layers satisfies the universal approximation theorem.
Reference:
Ian J. Goodfellow, David Warde-Farley, Mehdi Mirza, Aaron Courville, and Yoshua Bengio.
2013. Maxout networks.
In Proceedings of the 30th International Conference on International Conference on Machine Learning - Volume 28 (ICML'13),
Sanjoy Dasgupta and David McAllester (Eds.), Vol. 28. JMLR.org III-1319-III-1327.
https://arxiv.org/pdf/1302.4389.pdf
"""
struct Maxout{FS<:Tuple}
over::FS
end
"""
Maxout(f, n_alts)
Constructs a Maxout layer over `n_alts` instances of the layer given by `f`.
The function takes no arguments and should return some callable layer.
Conventionally this is a linear dense layer.
For example, the following constructs a `Maxout` layer over 4 internal
dense linear layers, each identical in structure (784 inputs, 128 outputs).
```julia
insize = 784
outsize = 128
Maxout(()->Dense(insize, outsize), 4)
```
"""
function Maxout(f, n_alts)
over = Tuple(f() for _ in 1:n_alts)
return Maxout(over)
end
function (mo::Maxout)(input::AbstractArray)
mapreduce(f -> f(input), (acc, out) -> max.(acc, out), mo.over)
end
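
For readers unfamiliar with `mapreduce`, the forward pass above is equivalent to applying every internal layer to the same input and folding `max.` over the results. A standalone sketch (not part of the diff; `maxout_naive` is an illustrative name):

```julia
# Equivalent to the mapreduce definition above: apply each internal layer
# to the same input, then take the elementwise maximum of the outputs.
function maxout_naive(layers, input)
  outs = [f(input) for f in layers]
  reduce((a, b) -> max.(a, b), outs)
end

mo = Maxout((x -> x, x -> 2x, x -> 0.5x))
x = rand(5)
maxout_naive(mo.over, x) == mo(x)   # true: both give 2x elementwise
```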

View File

@@ -30,4 +30,28 @@ using Test, Random
@test Flux.Diagonal(2)([1,2]) == [1,2]
@test Flux.Diagonal(2)([1 2; 3 4]) == [1 2; 3 4]
end
@testset "Maxout" begin
# Note that the typical usage of Maxout is as described in the docstring;
# the constructors below are artificial ones used only for testing.
@testset "Constructor" begin
mo = Maxout(() -> identity, 4)
input = rand(40)
@test mo(input) == input
end
@testset "simple alternatives" begin
mo = Maxout((x -> x, x -> 2x, x -> 0.5x))
input = rand(40)
@test mo(input) == 2*input
end
@testset "complex alternatives" begin
mo = Maxout((x -> [0.5; 0.1]*x, x -> [0.2; 0.7]*x))
input = [3.0 2.0]
target = [0.5, 0.7].*input
@test mo(input) == target
end
end
end
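
The "complex alternatives" test relies on some non-obvious broadcasting; here is the arithmetic worked through as a standalone sketch, using the same values as the test:

```julia
# x is a 1×2 row matrix and [0.5; 0.1] is a length-2 column vector, so each
# alternative produces a 2×2 outer product. The elementwise maximum keeps
# 0.5 on the first row and 0.7 on the second, i.e. exactly [0.5, 0.7] .* x.
x = [3.0 2.0]
a = [0.5; 0.1] * x              # [1.5 1.0; 0.3 0.2]
b = [0.2; 0.7] * x              # [0.6 0.4; 2.1 1.4]
max.(a, b) == [0.5, 0.7] .* x   # true
```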