Add MaxOut layer

Lyndon White 2019-02-27 12:04:59 +00:00
parent 79de829fdc
commit fcc3ec471a
3 changed files with 75 additions and 2 deletions


@@ -6,8 +6,10 @@ using Base: tail
using MacroTools, Juno, Requires, Reexport, Statistics, Random
using MacroTools: @forward
export Chain, Dense, RNN, LSTM, GRU, Conv, ConvTranspose, MaxPool, MeanPool,
DepthwiseConv, Dropout, AlphaDropout, LayerNorm, BatchNorm, InstanceNorm,
export Chain, Dense, MaxOut,
RNN, LSTM, GRU,
Conv, ConvTranspose, MaxPool, MeanPool, DepthwiseConv,
Dropout, AlphaDropout, LayerNorm, BatchNorm, InstanceNorm,
params, mapleaves, cpu, gpu, f32, f64
@reexport using NNlib


@@ -125,3 +125,47 @@ function Base.show(io::IO, l::Diagonal)
print(io, "Diagonal(", length(l.α), ")")
end
"""
MaxOut(over)
MaxOut is a neural network layer that applies a number of internal layers to the same input
and returns the elementwise maximum of the internal layers' outputs.
Maxout over linear dense layers satisfies the universal approximation theorem.
Reference:
Ian J. Goodfellow, David Warde-Farley, Mehdi Mirza, Aaron Courville, and Yoshua Bengio.
2013. Maxout networks.
In Proceedings of the 30th International Conference on International Conference on Machine Learning - Volume 28 (ICML'13),
Sanjoy Dasgupta and David McAllester (Eds.), Vol. 28. JMLR.org III-1319-III-1327.
https://arxiv.org/pdf/1302.4389.pdf
"""
struct MaxOut{FS<:Tuple}
over::FS
end
"""
MaxOut(f, n_alts, args...; kwargs...)
Constructs a MaxOut layer over `n_alts` instances of the layer given by `f`.
All other arguments (`args` & `kwargs`) are passed to the constructor `f`.
For example, the following constructs a MaxOut layer over 4 dense linear layers,
each identical in structure (784 inputs, 128 outputs):
```julia
insize = 784
outsize = 128
MaxOut(Dense, 4, insize, outsize)
```
"""
function MaxOut(f, n_alts, args...; kwargs...)
over = Tuple(f(args...; kwargs...) for _ in 1:n_alts)
return MaxOut(over)
end
function (mo::MaxOut)(input::AbstractArray)
mapreduce(f -> f(input), (acc, out) -> max.(acc, out), mo.over)
end
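
For context, a minimal usage sketch of the new layer (not part of the commit); the alternative functions, layer sizes, and inputs below are arbitrary illustrations:

```julia
# Illustrative sketch only: assumes the MaxOut layer defined above is loaded via Flux.
using Flux

# Explicit construction over a tuple of functions; the output is the
# elementwise maximum of the alternatives, so for non-negative x this is 2x.
m = MaxOut((identity, x -> 2x))
x = rand(3)
@assert m(x) == 2x

# Convenience constructor: 4 Dense(10, 5) layers sharing the same input.
mo = MaxOut(Dense, 4, 10, 5)
y = mo(rand(10))  # 5-element output, elementwise max over the 4 dense layers
```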


@@ -30,4 +30,31 @@ using Test, Random
@test Flux.Diagonal(2)([1,2]) == [1,2]
@test Flux.Diagonal(2)([1 2; 3 4]) == [1 2; 3 4]
end
@testset "MaxOut" begin
# Note that the typical usage of MaxOut is as shown in the docstring;
# these unusual constructions are used only for testing purposes
@testset "Constructor" begin
mo = MaxOut(() -> identity, 4)
input = rand(40)
@test mo(input) == input
end
@testset "simple alternatives" begin
mo = MaxOut((x -> x, x -> 2x, x -> 0.5x))
input = rand(40)
@test mo(input) == 2*input
end
@testset "complex alternatives" begin
mo = MaxOut((x -> [0.5; 0.1]*x, x -> [0.2; 0.7]*x))
input = [3.0 2.0]
target = [0.5, 0.7].*input
@test mo(input) == target
end
end
end