commit b637311642
@@ -5,14 +5,16 @@ These core layers form the foundation of almost all neural networks.
```@docs
Chain
Dense
```

## Convolution and Pooling Layers

These layers are used to build convolutional neural networks (CNNs).

```@docs
Conv
MaxPool
MeanPool
```

## Additional Convolution Layers

```@docs
DepthwiseConv
ConvTranspose
```
@@ -28,6 +30,25 @@ GRU
Flux.Recur
```

## Other General Purpose Layers
These are marginally more obscure than the Basic Layers.
In contrast to the layers described in the other sections, they are not readily grouped around a particular purpose (e.g. CNNs or RNNs).

```@docs
Maxout
```
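As a rough illustration (this sketch is not part of the generated docstrings, and the layer sizes are arbitrary), a `Maxout` over dense layers is used inside a `Chain` like any other layer:

```julia
using Flux

# four dense alternatives share the same 784-dimensional input;
# the layer returns their elementwise maximum
m = Chain(Maxout(() -> Dense(784, 256), 4), Dense(256, 10), softmax)
```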

# Normalisation & Regularisation

These layers don't affect the structure of the network but may improve training times or reduce overfitting.

```@docs
Flux.testmode!
BatchNorm
Dropout
LayerNorm
```
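For a rough sense of where these fit (a sketch with arbitrary layer sizes, not part of the docstring listing above), they slot into a `Chain` like any other layer:

```julia
using Flux

m = Chain(
  Dense(28^2, 128, relu),
  BatchNorm(128),   # normalise the 128 activations per batch
  Dropout(0.5),     # randomly zero half of them during training
  Dense(128, 10),
  softmax)
```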

## Activation Functions

Non-linearities that go between layers of your model. Most of these functions are defined in [NNlib](https://github.com/FluxML/NNlib.jl) but are available by default in Flux.
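For instance, `relu` and `σ` come from NNlib but work out of the box once Flux is loaded (a minimal sketch, not part of the diff):

```julia
using Flux

relu(-3)   # 0
σ(0)       # 0.5
```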

@@ -6,8 +6,10 @@ using Base: tail
using MacroTools, Juno, Requires, Reexport, Statistics, Random
using MacroTools: @forward

export Chain, Dense, RNN, LSTM, GRU, Conv, ConvTranspose, MaxPool, MeanPool,
       DepthwiseConv, Dropout, AlphaDropout, LayerNorm, BatchNorm, InstanceNorm,
export Chain, Dense, Maxout,
       RNN, LSTM, GRU,
       Conv, ConvTranspose, MaxPool, MeanPool, DepthwiseConv,
       Dropout, AlphaDropout, LayerNorm, BatchNorm, InstanceNorm,
       params, mapleaves, cpu, gpu, f32, f64

@reexport using NNlib

@@ -88,6 +88,14 @@ function Base.show(io::IO, l::Dense)
  print(io, ")")
end

# Try to avoid hitting generic matmul in some simple cases
# Base's matmul is so slow that it's worth the extra conversion to hit BLAS
(a::Dense{<:Any,W})(x::AbstractArray{T}) where {T <: Union{Float32,Float64}, W <: AbstractArray{T}} =
  invoke(a, Tuple{AbstractArray}, x)

(a::Dense{<:Any,W})(x::AbstractArray{<:Real}) where {T <: Union{Float32,Float64}, W <: AbstractArray{T}} =
  a(T.(x))
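# (The first method keeps inputs whose element type already matches the weights
#  on the fast path via `invoke`; the second converts other real inputs, e.g.
#  integers, to the weight element type `T` before applying the layer.)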

"""
    Diagonal(in::Integer)
@@ -117,10 +125,48 @@ function Base.show(io::IO, l::Diagonal)
  print(io, "Diagonal(", length(l.α), ")")
end

# Try to avoid hitting generic matmul in some simple cases
# Base's matmul is so slow that it's worth the extra conversion to hit BLAS
(a::Dense{<:Any,W})(x::AbstractArray{T}) where {T <: Union{Float32,Float64}, W <: AbstractArray{T}} =
  invoke(a, Tuple{AbstractArray}, x)

(a::Dense{<:Any,W})(x::AbstractArray{<:Real}) where {T <: Union{Float32,Float64}, W <: AbstractArray{T}} =
  a(T.(x))
"""
|
||||
Maxout(over)
|
||||
|
||||
`Maxout` is a neural network layer, which has a number of internal layers,
|
||||
which all have the same input, and the maxout returns the elementwise maximium
|
||||
of the internal layers' outputs.
|
||||
|
||||
Maxout over linear dense layers satisfies the univeral approximation theorem.
|
||||
|
||||
Reference:
|
||||
Ian J. Goodfellow, David Warde-Farley, Mehdi Mirza, Aaron Courville, and Yoshua Bengio.
|
||||
2013. Maxout networks.
|
||||
In Proceedings of the 30th International Conference on International Conference on Machine Learning - Volume 28 (ICML'13),
|
||||
Sanjoy Dasgupta and David McAllester (Eds.), Vol. 28. JMLR.org III-1319-III-1327.
|
||||
https://arxiv.org/pdf/1302.4389.pdf
|
||||
"""
|
||||
struct Maxout{FS<:Tuple}
|
||||
over::FS
|
||||
end
|
||||
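# `over` stores the tuple of alternative layers; the forward pass below reduces
# their outputs with an elementwise `max`.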

"""
    Maxout(f, n_alts)

Constructs a Maxout layer over `n_alts` instances of the layer given by `f`.
The function `f` takes no arguments and should return some callable layer.
Conventionally this is a linear dense layer.

For example, the following constructs a `Maxout` layer over 4 internal dense
linear layers, each identical in structure (784 inputs, 128 outputs):
```julia
insize = 784
outsize = 128
Maxout(() -> Dense(insize, outsize), 4)
```
"""
function Maxout(f, n_alts)
  over = Tuple(f() for _ in 1:n_alts)
  return Maxout(over)
end

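# Apply every internal layer to the same input and fold the outputs together
# with a broadcasted `max`, giving the elementwise maximum across alternatives.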
function (mo::Maxout)(input::AbstractArray)
  mapreduce(f -> f(input), (acc, out) -> max.(acc, out), mo.over)
end

@@ -30,4 +30,28 @@ using Test, Random
    @test Flux.Diagonal(2)([1,2]) == [1,2]
    @test Flux.Diagonal(2)([1 2; 3 4]) == [1 2; 3 4]
  end

  @testset "Maxout" begin
    # Note that the typical usage of Maxout is as per the docstring;
    # these are atypical constructors used only for testing purposes.

    @testset "Constructor" begin
      mo = Maxout(() -> identity, 4)
      input = rand(40)
      @test mo(input) == input
    end

    @testset "simple alternatives" begin
      mo = Maxout((x -> x, x -> 2x, x -> 0.5x))
      input = rand(40)
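      # rand(40) is non-negative, so 2x is always the largest of the three alternatives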
      @test mo(input) == 2*input
    end

    @testset "complex alternatives" begin
      mo = Maxout((x -> [0.5; 0.1]*x, x -> [0.2; 0.7]*x))
      input = [3.0 2.0]
      target = [0.5, 0.7].*input
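      # each alternative maps the 1×2 input to a 2×2 matrix; the elementwise
      # maximum keeps 0.5*input in the first row and 0.7*input in the second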
      @test mo(input) == target
    end
  end
end