NNlib docs + misc docs improvements

Carlo Lucibello 2020-02-29 11:14:48 +01:00
parent 2dd23574c0
commit 425fcdbe69
7 changed files with 115 additions and 66 deletions

View File

@@ -13,7 +13,8 @@ makedocs(modules=[Flux, NNlib],
 ["Basics" => "models/basics.md",
  "Recurrence" => "models/recurrence.md",
  "Regularisation" => "models/regularisation.md",
- "Model Reference" => "models/layers.md"],
+ "Model Reference" => "models/layers.md",
+ "NNlib" => "models/nnlib.md"],
 "Training Models" =>
 ["Optimisers" => "training/optimisers.md",
  "Training" => "training/training.md"],

View File

@@ -30,7 +30,7 @@ If you define a structured model, like a `Dense` layer or `Chain`, you just need
 ```julia
 d = Dense(10, 5, σ)
 d = fmap(cu, d)
-d.W # Tracked CuArray
+d.W # CuArray
 d(cu(rand(10))) # CuArray output
 m = Chain(Dense(10, 5, σ), Dense(5, 2), softmax)
@@ -53,7 +53,7 @@ julia> x = rand(10) |> gpu
 0.511655
 julia> m(x)
-Tracked 5-element CuArray{Float32,1}:
+5-element CuArray{Float32,1}:
 -0.30535
 -0.618002

View File

@@ -40,19 +40,6 @@ Maxout
 SkipConnection
 ```
-## Activation Functions
-Non-linearities that go between layers of your model. Most of these functions are defined in [NNlib](https://github.com/FluxML/NNlib.jl) but are available by default in Flux.
-Note that, unless otherwise stated, activation functions operate on scalars. To apply them to an array you can call `σ.(xs)`, `relu.(xs)` and so on.
-```@docs
-σ
-relu
-leakyrelu
-elu
-swish
-```
 ## Normalisation & Regularisation
@@ -61,6 +48,7 @@ These layers don't affect the structure of the network but may improve training
 ```@docs
 BatchNorm
 Dropout
+Flux.dropout
 AlphaDropout
 LayerNorm
 GroupNorm
@@ -68,12 +56,12 @@ GroupNorm
 ## Cost Functions
 ```@docs
-mse
-crossentropy
-logitcrossentropy
-binarycrossentropy
-logitbinarycrossentropy
-kldivergence
-poisson
-hinge
+Flux.mse
+Flux.crossentropy
+Flux.logitcrossentropy
+Flux.binarycrossentropy
+Flux.logitbinarycrossentropy
+Flux.kldivergence
+Flux.poisson
+Flux.hinge
 ```

docs/src/models/nnlib.md Normal file
View File

@@ -0,0 +1,37 @@
## NNlib
Flux re-exports all of the functions exported by the [NNlib](https://github.com/FluxML/NNlib.jl) package.
## Activation Functions
Non-linearities that go between layers of your model. Note that, unless otherwise stated, activation functions operate on scalars. To apply them to an array you can call `σ.(xs)`, `relu.(xs)` and so on.
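For instance, a quick sketch of broadcasting a few of these over an array (input values are arbitrary):
```julia
using Flux  # Flux re-exports the NNlib activations

xs = randn(Float32, 5)

σ.(xs)                  # elementwise sigmoid via broadcasting
relu.(xs)               # elementwise relu
leakyrelu.(xs, 0.02f0)  # extra arguments broadcast as well
```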
```@docs
NNlib.elu
NNlib.gelu
NNlib.leakyrelu
NNlib.logcosh
NNlib.logsigmoid
NNlib.sigmoid
NNlib.relu
NNlib.selu
NNlib.softplus
NNlib.softsign
NNlib.swish
```
## Softmax
```@docs
NNlib.softmax
NNlib.logsoftmax
```
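As a usage sketch, assuming the usual column-per-sample layout:
```julia
using NNlib: softmax

y = softmax(randn(Float32, 3, 4))  # one probability distribution per column
sum(y, dims = 1)                   # each column sums to 1 (up to rounding)
```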
## Pooling
```@docs
NNlib.maxpool
NNlib.meanpool
```
## Convolution
```@docs
NNlib.conv
NNlib.depthwiseconv
```
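A rough sketch of how these compose on a 4-D array in the width × height × channels × batch layout (all sizes below are made up for illustration):
```julia
using NNlib: conv, maxpool

x = rand(Float32, 28, 28, 1, 8)   # 28×28 images, 1 channel, batch of 8
w = rand(Float32, 5, 5, 1, 16)    # 5×5 kernels, 1 input channel, 16 output channels

y = conv(x, w; pad = 2)           # padding keeps the spatial size: (28, 28, 16, 8)
maxpool(y, (2, 2))                # a 2×2 window halves it: (14, 14, 16, 8)
```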

View File

@@ -31,7 +31,7 @@ julia> params(m)
 param([0.0, 0.0, 0.0, 0.0, 0.0])
 julia> sum(norm, params(m))
-26.01749952921026 (tracked)
+26.01749952921026
 ```
 Here's a larger example with a multi-layer perceptron.
@@ -52,7 +52,7 @@ One can also easily add per-layer regularisation via the `activations` function:
 ```julia
 julia> using Flux: activations
-julia> c = Chain(Dense(10,5,σ),Dense(5,2),softmax)
+julia> c = Chain(Dense(10, 5, σ), Dense(5, 2), softmax)
 Chain(Dense(10, 5, σ), Dense(5, 2), softmax)
 julia> activations(c, rand(10))

View File

@@ -7,6 +7,16 @@ _dropout_shape(s, dims) = tuple((i ∉ dims ? 1 : si for (i, si) ∈ enumerate(s
 _dropout_kernel(y::T, p, q) where {T} = y > p ? T(1 / q) : T(0)
+"""
+    dropout(x, p; dims = :)
+Dropout function. For each input, either sets that input to `0` (with probability
+`p`) or scales it by `1/(1-p)`. The `dims` argument specifies the unbroadcasted
+dimensions, i.e. `dims=1` does dropout along columns and `dims=2` along rows. This is
+used as a regularisation, i.e. it reduces overfitting during training.
+See also [`Dropout`](@ref).
+"""
 dropout(x, p; dims = :) = x
 @adjoint function dropout(x, p; dims = :)
@@ -18,10 +28,7 @@ end
 """
     Dropout(p, dims = :)
-A Dropout layer. For each input, either sets that input to `0` (with probability
-`p`) or scales it by `1/(1-p)`. The `dims` argument is to specified the unbroadcasted
-dimensions, i.e. `dims=1` does dropout along columns and `dims=2` along rows. This is
-used as a regularisation, i.e. it reduces overfitting during training. see also [`dropout`](@ref).
+A Dropout layer. In the forward pass, applies the [`dropout`](@ref) function to the input.
 """
 mutable struct Dropout{F,D}
 p::F
@@ -43,6 +50,7 @@ end
 """
     AlphaDropout(p)
 A dropout layer. It is used in Self-Normalizing Neural Networks.
 (https://papers.nips.cc/paper/6698-self-normalizing-neural-networks.pdf)
 The AlphaDropout layer ensures that mean and variance of activations remain the same as before.

View File

@@ -1,10 +1,12 @@
 using CuArrays
 using NNlib: logsoftmax, logσ
 # Cost functions
+"""
+    mse(ŷ, y)
+Return the mean squared error `sum((ŷ .- y).^2) / length(y)`.
+"""
 mse(ŷ, y) = sum((ŷ .- y).^2) * 1 // length(y)
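# Illustrative check with made-up numbers (not from the Flux sources):
# mse([0.9, 0.2], [1.0, 0.0]) == ((0.9 - 1.0)^2 + (0.2 - 0.0)^2) / 2 == 0.025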
 function _crossentropy(ŷ::AbstractVecOrMat, y::AbstractVecOrMat, weight::Nothing)
 return -sum(y .* log.(ŷ)) * 1 // size(y, 2)
 end
@@ -17,10 +19,26 @@ function _crossentropy(ŷ::AbstractVecOrMat, y::AbstractVecOrMat, weight::Abstr
 return -sum(y .* log.(ŷ) .* weight) * 1 // size(y, 2)
 end
+"""
+    crossentropy(ŷ, y; weight=1)
+Return the crossentropy computed as `-sum(y .* log.(ŷ) .* weight) / size(y, 2)`.
+See also [`logitcrossentropy`](@ref), [`binarycrossentropy`](@ref).
+"""
 crossentropy(ŷ::AbstractVecOrMat, y::AbstractVecOrMat; weight=nothing) = _crossentropy(ŷ, y, weight)
-function logitcrossentropy(logŷ::AbstractVecOrMat, y::AbstractVecOrMat; weight = 1)
-  return -sum(y .* logsoftmax(logŷ) .* weight) * 1 // size(y, 2)
+"""
+    logitcrossentropy(ŷ, y; weight=1)
+Return the crossentropy computed after a [softmax](@ref) operation:
+    -sum(y .* logsoftmax(ŷ) .* weight) / size(y, 2)
+See also [`crossentropy`](@ref), [`binarycrossentropy`](@ref).
+"""
+function logitcrossentropy(ŷ::AbstractVecOrMat, y::AbstractVecOrMat; weight = 1)
+  return -sum(y .* logsoftmax(ŷ) .* weight) * 1 // size(y, 2)
 end
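# Illustrative check with made-up values (not from the Flux sources):
# crossentropy([0.1, 0.2, 0.7], [0, 0, 1]) ≈ -log(0.7) ≈ 0.357
# logitcrossentropy(log.([0.1, 0.2, 0.7]), [0, 0, 1]) gives the same value from unnormalised logits.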
""" """
@@ -28,11 +46,7 @@ end
 Return `-y*log(ŷ + ϵ) - (1-y)*log(1-ŷ + ϵ)`. The ϵ term provides numerical stability.
-    julia> binarycrossentropy.(σ.([-1.1491, 0.8619, 0.3127]), [1, 1, 0.])
-    3-element Array{Float64,1}:
-    1.4244
-    0.352317
-    0.86167
+Typically, the prediction `ŷ` is given by the output of a [`sigmoid`](@ref) activation.
 """
 binarycrossentropy(ŷ, y; ϵ=eps(ŷ)) = -y*log(ŷ + ϵ) - (1 - y)*log(1 - ŷ + ϵ)
@@ -40,44 +54,42 @@ binarycrossentropy(ŷ, y; ϵ=eps(ŷ)) = -y*log(ŷ + ϵ) - (1 - y)*log(1 - ŷ
 CuArrays.@cufunc binarycrossentropy(ŷ, y; ϵ=eps(ŷ)) = -y*log(ŷ + ϵ) - (1 - y)*log(1 - ŷ + ϵ)
 """
-    logitbinarycrossentropy(logŷ, y)
-`logitbinarycrossentropy(logŷ, y)` is mathematically equivalent to `binarycrossentropy(σ(logŷ), y)`
+    logitbinarycrossentropy(ŷ, y)
+`logitbinarycrossentropy(ŷ, y)` is mathematically equivalent to `binarycrossentropy(σ(ŷ), y)`
 but it is more numerically stable.
-    julia> logitbinarycrossentropy.([-1.1491, 0.8619, 0.3127], [1, 1, 0.])
-    3-element Array{Float64,1}:
-    1.4244
-    0.352317
-    0.86167
+See also [`binarycrossentropy`](@ref), [`sigmoid`](@ref), [`logsigmoid`](@ref).
 """
-logitbinarycrossentropy(logŷ, y) = (1 - y)*logŷ - logσ(logŷ)
+logitbinarycrossentropy(ŷ, y) = (1 - y)*ŷ - logσ(ŷ)
 # Re-definition to fix interaction with CuArrays.
-CuArrays.@cufunc logitbinarycrossentropy(logŷ, y) = (1 - y)*logŷ - logσ(logŷ)
+CuArrays.@cufunc logitbinarycrossentropy(ŷ, y) = (1 - y)*ŷ - logσ(ŷ)
""" """
normalise(x::AbstractArray; dims=1) normalise(x; dims=1)
Normalises `x` to mean 0 and standard deviation 1, across the dimensions given by `dims`. Defaults to normalising over columns. Normalises `x` to mean 0 and standard deviation 1, across the dimensions given by `dims`. Defaults to normalising over columns.
julia> a = reshape(collect(1:9), 3, 3) ```julia-repl
3×3 Array{Int64,2}: julia> a = reshape(collect(1:9), 3, 3)
1 4 7 3×3 Array{Int64,2}:
2 5 8 1 4 7
3 6 9 2 5 8
3 6 9
julia> normalise(a) julia> normalise(a)
3×3 Array{Float64,2}: 3×3 Array{Float64,2}:
-1.22474 -1.22474 -1.22474 -1.22474 -1.22474 -1.22474
0.0 0.0 0.0 0.0 0.0 0.0
1.22474 1.22474 1.22474 1.22474 1.22474 1.22474
julia> normalise(a, dims=2) julia> normalise(a, dims=2)
3×3 Array{Float64,2}: 3×3 Array{Float64,2}:
-1.22474 0.0 1.22474 -1.22474 0.0 1.22474
-1.22474 0.0 1.22474 -1.22474 0.0 1.22474
-1.22474 0.0 1.22474 -1.22474 0.0 1.22474
```
""" """
function normalise(x::AbstractArray; dims=1) function normalise(x::AbstractArray; dims=1)
μ′ = mean(x, dims = dims) μ′ = mean(x, dims = dims)
@@ -87,6 +99,7 @@ end
 """
     kldivergence(ŷ, y)
 KLDivergence is a measure of how much one probability distribution is different from the other.
 It is always non-negative and zero only when both the distributions are equal everywhere.
 [KL Divergence](https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence).
@@ -99,6 +112,7 @@ end
 """
     poisson(ŷ, y)
 Poisson loss function is a measure of how the predicted distribution diverges from the expected distribution.
 [Poisson Loss](https://peltarion.com/knowledge-center/documentation/modeling-view/build-an-ai-model/loss-functions/poisson).
 """
@@ -106,7 +120,8 @@ poisson(ŷ, y) = sum(ŷ .- y .* log.(ŷ)) *1 // size(y,2)
 """
     hinge(ŷ, y)
-Measures the loss given the prediction and true labels y(containing 1 or -1).
+Measures the loss given the prediction `ŷ` and true labels `y` (containing 1 or -1).
 [Hinge Loss](https://en.wikipedia.org/wiki/Hinge_loss).
 """
 hinge(ŷ, y) = sum(max.(0, 1 .- ŷ .* y)) *1 // size(y,2)
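# Illustrative check with made-up numbers (not from the Flux sources):
# hinge([0.8, -0.4], [1, -1]) == (max(0, 1 - 0.8*1) + max(0, 1 - (-0.4)*(-1))) / 1 == 0.2 + 0.6 == 0.8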