NNlib docs + misc docs improvements
commit 425fcdbe69 (parent 2dd23574c0)
@@ -13,7 +13,8 @@ makedocs(modules=[Flux, NNlib],
["Basics" => "models/basics.md",
"Recurrence" => "models/recurrence.md",
"Regularisation" => "models/regularisation.md",
"Model Reference" => "models/layers.md"],
"Model Reference" => "models/layers.md",
"NNlib" => "models/nnlib.md"],
"Training Models" =>
["Optimisers" => "training/optimisers.md",
"Training" => "training/training.md"],
@@ -30,7 +30,7 @@ If you define a structured model, like a `Dense` layer or `Chain`, you just need
```julia
d = Dense(10, 5, σ)
d = fmap(cu, d)
d.W # Tracked CuArray
d.W # CuArray
d(cu(rand(10))) # CuArray output

m = Chain(Dense(10, 5, σ), Dense(5, 2), softmax)
@@ -53,7 +53,7 @@ julia> x = rand(10) |> gpu
0.511655

julia> m(x)
Tracked 5-element CuArray{Float32,1}:
5-element CuArray{Float32,1}:
-0.30535
⋮
-0.618002
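Illustration (not part of the commit): the snippets above can be exercised end to end; a minimal sketch assuming CuArrays is installed and a CUDA device is available.

```julia
using Flux, CuArrays

# move both the model and the data to the GPU
m = Chain(Dense(10, 5, σ), Dense(5, 2), softmax) |> gpu
x = rand(10) |> gpu

m(x)   # 2-element CuArray{Float32,1}
```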
@@ -40,19 +40,6 @@ Maxout
SkipConnection
```

## Activation Functions

Non-linearities that go between layers of your model. Most of these functions are defined in [NNlib](https://github.com/FluxML/NNlib.jl) but are available by default in Flux.

Note that, unless otherwise stated, activation functions operate on scalars. To apply them to an array you can call `σ.(xs)`, `relu.(xs)` and so on.

```@docs
σ
relu
leakyrelu
elu
swish
```

## Normalisation & Regularisation

@@ -61,6 +48,7 @@ These layers don't affect the structure of the network but may improve training
```@docs
BatchNorm
Dropout
Flux.dropout
AlphaDropout
LayerNorm
GroupNorm
@@ -68,12 +56,12 @@ GroupNorm

## Cost Functions
```@docs
mse
crossentropy
logitcrossentropy
binarycrossentropy
logitbinarycrossentropy
kldivergence
poisson
hinge
Flux.mse
Flux.crossentropy
Flux.logitcrossentropy
Flux.binarycrossentropy
Flux.logitbinarycrossentropy
Flux.kldivergence
Flux.poisson
Flux.hinge
```
@@ -0,0 +1,37 @@
## NNlib
Flux re-exports all of the functions exported by the [NNlib](https://github.com/FluxML/NNlib.jl) package.

## Activation Functions
Non-linearities that go between layers of your model. Note that, unless otherwise stated, activation functions operate on scalars. To apply them to an array you can call `σ.(xs)`, `relu.(xs)` and so on.

```@docs
NNlib.elu
NNlib.gelu
NNlib.leakyrelu
NNlib.logcosh
NNlib.logsigmoid
NNlib.sigmoid
NNlib.relu
NNlib.selu
NNlib.softplus
NNlib.softsign
NNlib.swish
```

## Softmax
```@docs
NNlib.softmax
NNlib.logsoftmax
```

## Pooling
```@docs
NNlib.maxpool
NNlib.meanpool
```

## Convolution
```@docs
NNlib.conv
NNlib.depthwiseconv
```
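Illustration (not part of the commit): the new page stresses that the activations are scalar functions applied via broadcasting; a small sketch assuming only that Flux re-exports NNlib as stated above.

```julia
using Flux   # σ, relu, softmax, … are re-exported from NNlib

xs = [-1.0, 0.0, 2.0]

relu.(xs)     # [0.0, 0.0, 2.0] — applied elementwise via broadcasting
σ.(xs)        # elementwise sigmoid, each value in (0, 1)
softmax(xs)   # operates on the whole vector; entries sum to 1
```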
@@ -31,7 +31,7 @@ julia> params(m)
param([0.0, 0.0, 0.0, 0.0, 0.0])

julia> sum(norm, params(m))
26.01749952921026 (tracked)
26.01749952921026
```

Here's a larger example with a multi-layer perceptron.
@@ -7,6 +7,16 @@ _dropout_shape(s, dims) = tuple((i ∉ dims ? 1 : si for (i, si) ∈ enumerate(s

_dropout_kernel(y::T, p, q) where {T} = y > p ? T(1 / q) : T(0)

"""
dropout(x, p; dims = :)

Dropout function. For each input, either sets that input to `0` (with probability
`p`) or scales it by `1/(1-p)`. The `dims` argument specifies the unbroadcasted
dimensions, i.e. `dims=1` does dropout along columns and `dims=2` along rows. This is
used as a regularisation, i.e. it reduces overfitting during training.

See also [`Dropout`](@ref).
"""
dropout(x, p; dims = :) = x

@adjoint function dropout(x, p; dims = :)
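To make the docstring added above concrete, a hypothetical sketch of the behaviour it describes (`dropout_sketch` is illustrative only, not the function defined in this file):

```julia
# Keep each entry with probability 1 - p, scale survivors by 1/(1 - p).
# The mask has full size along `dims` and singleton size elsewhere, so it broadcasts.
function dropout_sketch(x, p; dims = :)
    sz = dims === Colon() ? size(x) :
         ntuple(i -> i in dims ? size(x, i) : 1, ndims(x))
    keep = rand(Float32, sz...) .> p
    return x .* keep ./ (1 - p)
end

x = ones(Float32, 4, 3)
dropout_sketch(x, 0.5)            # roughly half the entries become 0, the rest 2.0
dropout_sketch(x, 0.5, dims = 1)  # mask is shared across dim 2, so whole rows drop together
```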
@@ -18,10 +28,7 @@ end
"""
Dropout(p, dims = :)

A Dropout layer. For each input, either sets that input to `0` (with probability
`p`) or scales it by `1/(1-p)`. The `dims` argument is to specified the unbroadcasted
dimensions, i.e. `dims=1` does dropout along columns and `dims=2` along rows. This is
used as a regularisation, i.e. it reduces overfitting during training. see also [`dropout`](@ref).
A Dropout layer. In the forward pass, applies the [`dropout`](@ref) function on the input.
"""
mutable struct Dropout{F,D}
p::F
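Illustration (standard Flux usage, not taken from this commit): where a `Dropout` layer typically sits in a model.

```julia
using Flux

m = Chain(
    Dense(28^2, 128, relu),
    Dropout(0.5),            # applies the `dropout` function to its input
    Dense(128, 10),
    softmax)

m(rand(Float32, 28^2))       # 10-element output
```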
@@ -43,6 +50,7 @@ end

"""
AlphaDropout(p)

A dropout layer. It is used in Self-Normalizing Neural Networks.
(https://papers.nips.cc/paper/6698-self-normalizing-neural-networks.pdf)
The AlphaDropout layer ensures that the mean and variance of activations remain the same as before.
@@ -1,10 +1,12 @@
using CuArrays
using NNlib: logsoftmax, logσ

# Cost functions
"""
mse(ŷ, y)

Return the mean squared error `sum((ŷ .- y).^2) / length(y)`.
"""
mse(ŷ, y) = sum((ŷ .- y).^2) * 1 // length(y)


function _crossentropy(ŷ::AbstractVecOrMat, y::AbstractVecOrMat, weight::Nothing)
return -sum(y .* log.(ŷ)) * 1 // size(y, 2)
end
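A quick worked instance of the formula quoted in the new `mse` docstring (a sketch; assumes `Flux` is loaded):

```julia
ŷ = [0.9, 0.2, 0.1]
y = [1.0, 0.0, 0.0]

sum((ŷ .- y).^2) / length(y)   # (0.01 + 0.04 + 0.01) / 3 = 0.02
Flux.mse(ŷ, y)                 # same value
```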
@@ -17,10 +19,26 @@ function _crossentropy(ŷ::AbstractVecOrMat, y::AbstractVecOrMat, weight::Abstr
return -sum(y .* log.(ŷ) .* weight) * 1 // size(y, 2)
end

"""
crossentropy(ŷ, y; weight=1)

Return the crossentropy computed as `-sum(y .* log.(ŷ) .* weight) / size(y, 2)`.

See also [`logitcrossentropy`](@ref), [`binarycrossentropy`](@ref).
"""
crossentropy(ŷ::AbstractVecOrMat, y::AbstractVecOrMat; weight=nothing) = _crossentropy(ŷ, y, weight)

function logitcrossentropy(logŷ::AbstractVecOrMat, y::AbstractVecOrMat; weight = 1)
return -sum(y .* logsoftmax(logŷ) .* weight) * 1 // size(y, 2)
"""
logitcrossentropy(ŷ, y; weight=1)

Return the crossentropy computed after a [softmax](@ref) operation:

-sum(y .* logsoftmax(ŷ) .* weight) / size(y, 2)

See also [`crossentropy`](@ref), [`binarycrossentropy`](@ref).
"""
function logitcrossentropy(ŷ::AbstractVecOrMat, y::AbstractVecOrMat; weight = 1)
return -sum(y .* logsoftmax(ŷ) .* weight) * 1 // size(y, 2)
end

"""
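A small check of the equivalence the new `logitcrossentropy` docstring states (a sketch; it assumes these names are importable from Flux, as the docs changes above suggest):

```julia
using Flux: crossentropy, logitcrossentropy, softmax

logits = randn(3, 5)            # raw scores: 3 classes × 5 samples
y = [1 0 0 0 1;                 # one-hot targets as a plain matrix
     0 1 0 1 0;
     0 0 1 0 0.0]

logitcrossentropy(logits, y) ≈ crossentropy(softmax(logits), y)   # true
```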
@@ -28,11 +46,7 @@ end

Return `-y*log(ŷ + ϵ) - (1-y)*log(1-ŷ + ϵ)`. The ϵ term provides numerical stability.

julia> binarycrossentropy.(σ.([-1.1491, 0.8619, 0.3127]), [1, 1, 0.])
3-element Array{Float64,1}:
1.4244
0.352317
0.86167
Typically, the prediction `ŷ` is given by the output of a [`sigmoid`](@ref) activation.
"""
binarycrossentropy(ŷ, y; ϵ=eps(ŷ)) = -y*log(ŷ + ϵ) - (1 - y)*log(1 - ŷ + ϵ)
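The numbers from the old inline example still make a handy sanity check of the sigmoid/logit relationship mentioned in these docstrings (a sketch; assumes Flux is loaded):

```julia
using Flux: binarycrossentropy, logitbinarycrossentropy, σ

logits = [-1.1491, 0.8619, 0.3127]
y = [1, 1, 0.]

binarycrossentropy.(σ.(logits), y)     # ≈ [1.4244, 0.352317, 0.86167]
logitbinarycrossentropy.(logits, y)    # ≈ the same values, computed more stably
```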
@@ -40,27 +54,24 @@ binarycrossentropy(ŷ, y; ϵ=eps(ŷ)) = -y*log(ŷ + ϵ) - (1 - y)*log(1 - ŷ
CuArrays.@cufunc binarycrossentropy(ŷ, y; ϵ=eps(ŷ)) = -y*log(ŷ + ϵ) - (1 - y)*log(1 - ŷ + ϵ)

"""
logitbinarycrossentropy(logŷ, y)
logitbinarycrossentropy(ŷ, y)

`logitbinarycrossentropy(logŷ, y)` is mathematically equivalent to `binarycrossentropy(σ(logŷ), y)`
`logitbinarycrossentropy(ŷ, y)` is mathematically equivalent to `binarycrossentropy(σ(ŷ), y)`
but it is more numerically stable.

julia> logitbinarycrossentropy.([-1.1491, 0.8619, 0.3127], [1, 1, 0.])
3-element Array{Float64,1}:
1.4244
0.352317
0.86167
See also [`binarycrossentropy`](@ref), [`sigmoid`](@ref), [`logsigmoid`](@ref).
"""
logitbinarycrossentropy(logŷ, y) = (1 - y)*logŷ - logσ(logŷ)
logitbinarycrossentropy(ŷ, y) = (1 - y)*ŷ - logσ(ŷ)

# Re-definition to fix interaction with CuArrays.
CuArrays.@cufunc logitbinarycrossentropy(logŷ, y) = (1 - y)*logŷ - logσ(logŷ)
CuArrays.@cufunc logitbinarycrossentropy(ŷ, y) = (1 - y)*ŷ - logσ(ŷ)

"""
normalise(x::AbstractArray; dims=1)
normalise(x; dims=1)

Normalises `x` to mean 0 and standard deviation 1, across the dimensions given by `dims`. Defaults to normalising over columns.

```julia-repl
julia> a = reshape(collect(1:9), 3, 3)
3×3 Array{Int64,2}:
1 4 7
@@ -78,6 +89,7 @@ Normalises `x` to mean 0 and standard deviation 1, across the dimensions given b
-1.22474 0.0 1.22474
-1.22474 0.0 1.22474
-1.22474 0.0 1.22474
```
"""
function normalise(x::AbstractArray; dims=1)
μ′ = mean(x, dims = dims)
@@ -87,6 +99,7 @@ end

"""
kldivergence(ŷ, y)

KLDivergence is a measure of how much one probability distribution differs from another.
It is always non-negative, and zero only when both distributions are equal everywhere.
[KL Divergence](https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence).
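Illustration of the property stated above (a sketch; it assumes `Flux.kldivergence` treats each column/vector as a probability distribution, like the other cost functions):

```julia
p = [0.1, 0.2, 0.7]
q = [0.3, 0.3, 0.4]

Flux.kldivergence(p, p)   # ≈ 0: identical distributions
Flux.kldivergence(q, p)   # > 0: penalises q for diverging from the target p
```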
@@ -99,6 +112,7 @@ end

"""
poisson(ŷ, y)

The Poisson loss function measures how the predicted distribution diverges from the expected distribution.
[Poisson Loss](https://peltarion.com/knowledge-center/documentation/modeling-view/build-an-ai-model/loss-functions/poisson).
"""
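A worked instance of the Poisson loss formula `sum(ŷ .- y .* log.(ŷ)) / size(y, 2)`, which appears in the context line of the next hunk (plain Julia sketch):

```julia
ŷ = [0.5, 1.5]   # predicted rates
y = [1.0, 2.0]   # observed counts

sum(ŷ .- y .* log.(ŷ)) * 1 // size(y, 2)
# = (0.5 + 0.6931) + (1.5 - 0.8109) ≈ 1.882
# (for a vector, size(y, 2) == 1, so this is the plain sum)
```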
@@ -106,7 +120,8 @@ poisson(ŷ, y) = sum(ŷ .- y .* log.(ŷ)) *1 // size(y,2)

"""
hinge(ŷ, y)
Measures the loss given the prediction ŷ and true labels y(containing 1 or -1).

Measures the loss given the prediction `ŷ` and true labels `y` (containing 1 or -1).
[Hinge Loss](https://en.wikipedia.org/wiki/Hinge_loss).
"""
hinge(ŷ, y) = sum(max.(0, 1 .- ŷ .* y)) *1 // size(y,2)
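A worked instance of the hinge formula defined above (plain Julia sketch):

```julia
ŷ = [0.7, -0.2, 2.1]   # raw predictions
y = [1, -1, 1]         # labels in {-1, 1}

sum(max.(0, 1 .- ŷ .* y)) * 1 // size(y, 2)
# margins ŷ .* y = [0.7, 0.2, 2.1] → losses [0.3, 0.8, 0.0] → 1.1
# (for a vector, size(y, 2) == 1, so this is the sum rather than a mean)
```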