Further docstring improvements in src/

Some had to be redone after the rebase
This commit is contained in:
janEbert 2020-04-04 22:59:45 +02:00
parent 64ce32ddcf
commit 2ce5f6d9bf
5 changed files with 100 additions and 102 deletions

@@ -183,18 +183,11 @@ outdims(l::Diagonal, isize) = (length(l.α),)
"""
Maxout(over)
`Maxout` is a neural network layer which has a number of internal layers
which all receive the same input. The layer returns the elementwise maximium
of the internal layers' outputs.
The [Maxout](https://arxiv.org/pdf/1302.4389.pdf) layer has a number of
internal layers which all receive the same input. It returns the elementwise
maximum of the internal layers' outputs.
Maxout over linear dense layers satisfies the universal approximation theorem.
Reference:
Ian J. Goodfellow, David Warde-Farley, Mehdi Mirza, Aaron Courville, and Yoshua Bengio.
2013. Maxout networks.
In Proceedings of the 30th International Conference on International Conference on Machine Learning - Volume 28 (ICML'13),
Sanjoy Dasgupta and David McAllester (Eds.), Vol. 28. JMLR.org III-1319-III-1327.
https://arxiv.org/pdf/1302.4389.pdf
"""
struct Maxout{FS<:Tuple}
over::FS
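
To make the new docstring concrete, here is a minimal usage sketch; the `Dense` layers, sizes, and input are illustrative assumptions:

```julia
using Flux

# Three internal layers that all receive the same input.
layers = (Dense(2, 4), Dense(2, 4), Dense(2, 4))
m = Maxout(layers)

x = rand(Float32, 2, 5)  # 2 features, batch of 5
y = m(x)                 # elementwise maximum of the three outputs, size (4, 5)
```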

@@ -65,9 +65,10 @@ end
"""
AlphaDropout(p)
A dropout layer. It is used in
A dropout layer. Used in
[Self-Normalizing Neural Networks](https://papers.nips.cc/paper/6698-self-normalizing-neural-networks.pdf).
The AlphaDropout layer ensures that mean and variance of activations remains the same as before.
The AlphaDropout layer ensures that mean and variance of activations
remain the same as before.
Does nothing to the input once [`testmode!`](@ref) is true.
"""
@@ -123,8 +124,8 @@ end
initβ = zeros, initγ = ones,
ϵ = 1e-8, momentum = .1)
Batch Normalization layer. The `channels` input should be the size of the
channel dimension in your data (see below).
[Batch Normalization](https://arxiv.org/pdf/1502.03167.pdf) layer.
`channels` should be the size of the channel dimension in your data (see below).
Given an array with `N` dimensions, call the `N-1`th the channel dimension. (For
a batch of feature vectors this is just the data dimension, for `WHCN` images
@@ -136,9 +137,6 @@ per-channel `bias` and `scale` parameters).
Use [`testmode!`](@ref) during inference.
See [Batch Normalization: Accelerating Deep Network Training by Reducing
Internal Covariate Shift](https://arxiv.org/pdf/1502.03167.pdf).
# Examples
```julia
m = Chain(
@@ -213,37 +211,6 @@ function Base.show(io::IO, l::BatchNorm)
print(io, ")")
end
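
The channel-dimension convention described for `BatchNorm` above can be illustrated with a short sketch; the `Conv` layer and input sizes are illustrative assumptions:

```julia
using Flux

m = Chain(Conv((3, 3), 3 => 16), BatchNorm(16, relu))

x = rand(Float32, 28, 28, 3, 8)  # WHCN: 28×28 RGB images, batch of 8
y = m(x)                         # BatchNorm normalises over the 16 channels (the N-1th dimension)
```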
"""
InstanceNorm(channels::Integer, σ = identity;
initβ = zeros, initγ = ones,
ϵ = 1e-8, momentum = .1)
Instance Normalization layer. The `channels` input should be the size of the
channel dimension in your data (see below).
Given an array with `N` dimensions, call the `N-1`th the channel dimension. (For
a batch of feature vectors this is just the data dimension, for `WHCN` images
it's the usual channel dimension.)
`InstanceNorm` computes the mean and variance for each each `W×H×1×1` slice and
shifts them to have a new mean and variance (corresponding to the learnable,
per-channel `bias` and `scale` parameters).
Use [`testmode!`](@ref) during inference.
See [Instance Normalization: The Missing Ingredient for Fast Stylization](https://arxiv.org/abs/1607.08022).
# Examples
```julia
m = Chain(
Dense(28^2, 64),
InstanceNorm(64, relu),
Dense(64, 10),
InstanceNorm(10),
softmax)
```
"""
expand_inst = (x, as) -> reshape(repeat(x, outer=[1, as[length(as)]]), as...)
mutable struct InstanceNorm{F,V,W,N}
@@ -258,6 +225,34 @@ mutable struct InstanceNorm{F,V,W,N}
end
# TODO: deprecate in v0.11
"""
InstanceNorm(channels::Integer, σ = identity;
initβ = zeros, initγ = ones,
ϵ = 1e-8, momentum = .1)
[Instance Normalization](https://arxiv.org/abs/1607.08022) layer.
`channels` should be the size of the channel dimension in your data (see below).
Given an array with `N` dimensions, call the `N-1`th the channel dimension. (For
a batch of feature vectors this is just the data dimension, for `WHCN` images
it's the usual channel dimension.)
`InstanceNorm` computes the mean and variance for each `W×H×1×1` slice and
shifts them to have a new mean and variance (corresponding to the learnable,
per-channel `bias` and `scale` parameters).
Use [`testmode!`](@ref) during inference.
# Examples
```julia
m = Chain(
Dense(28^2, 64),
InstanceNorm(64, relu),
Dense(64, 10),
InstanceNorm(10),
softmax)
```
"""
InstanceNorm(λ, β, γ, μ, σ², ϵ, momentum) = InstanceNorm(λ, β, γ, μ, σ², ϵ, momentum, nothing)
InstanceNorm(chs::Integer, λ = identity;

@@ -2,7 +2,8 @@
"""
mae(ŷ, y)
Return the mean of absolute error `sum(abs.(ŷ .- y)) / length(y)`
Return the mean of absolute error; calculated as
`sum(abs.(ŷ .- y)) / length(y)`.
"""
mae(ŷ, y) = sum(abs.(ŷ .- y)) * 1 // length(y)
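
A quick plain-Julia check of the `mae` formula above, with made-up values:

```julia
ŷ = [1.0, 2.0, 3.0]
y = [1.5, 2.0, 2.0]
sum(abs.(ŷ .- y)) / length(y)  # (0.5 + 0.0 + 1.0) / 3 == 0.5
```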
@@ -10,8 +11,8 @@ mae(ŷ, y) = sum(abs.(ŷ .- y)) * 1 // length(y)
"""
mse(ŷ, y)
Return the mean squared error between ŷ and y;
defined as ``\\frac{1}{n} \\sum_{i=1}^n (ŷ_i - y_i)^2``.
Return the mean squared error between ŷ and y; calculated as
`sum((ŷ .- y).^2) / length(y)`.
# Examples
```jldoctest
@@ -25,10 +26,11 @@ mse(ŷ, y) = sum((ŷ .- y).^2) * 1 // length(y)
"""
msle(ŷ, y; ϵ=eps(eltype(ŷ)))
Returns the mean of the squared logarithmic errors `sum((log.(ŷ .+ ϵ) .- log.(y .+ ϵ)).^2) / length(y)`.
Return the mean of the squared logarithmic errors; calculated as
`sum((log.(ŷ .+ ϵ) .- log.(y .+ ϵ)).^2) / length(y)`.
The `ϵ` term provides numerical stability.
This error penalizes an under-predicted estimate greater than an over-predicted estimate.
Penalizes an under-predicted estimate greater than an over-predicted estimate.
"""
msle(ŷ, y; ϵ=eps(eltype(ŷ))) = sum((log.(ŷ .+ ϵ) .- log.(y .+ ϵ)).^2) * 1 // length(y)
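
A short sketch of the `msle` formula, showing the role of the `ϵ` term when entries are zero:

```julia
ŷ = [10.0, 2.0, 0.0]
y = [9.0, 2.0, 0.0]
ϵ = eps(eltype(ŷ))
sum((log.(ŷ .+ ϵ) .- log.(y .+ ϵ)).^2) / length(y)
# Finite result; without ϵ, the zero entries would produce log(0) = -Inf and a NaN total.
```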
@@ -37,13 +39,12 @@ msle(ŷ, y; ϵ=eps(eltype(ŷ))) = sum((log.(ŷ .+ ϵ) .- log.(y .+ ϵ)).^2) *
"""
huber_loss(ŷ, y; δ=1.0)
Computes the mean of the Huber loss given the prediction `ŷ` and true values `y`. By default, δ is set to 1.0.
Return the mean of the [Huber loss](https://en.wikipedia.org/wiki/Huber_loss)
given the prediction `ŷ` and true values `y`.
| 0.5*|ŷ - y|, for |ŷ - y| <= δ
Hubber loss = |
| δ*(|ŷ - y| - 0.5*δ), otherwise
[`Huber Loss`](https://en.wikipedia.org/wiki/Huber_loss).
| 0.5 * |ŷ - y|^2, for |ŷ - y| <= δ
Huber loss = |
| δ * (|ŷ - y| - 0.5 * δ), otherwise
"""
function huber_loss(ŷ, y; δ=eltype(ŷ)(1))
abs_error = abs.(ŷ .- y)
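
A quick numeric check of the two branches in the piecewise formula above, with δ = 1:

```julia
δ = 1.0
small, large = 0.4, 3.0   # absolute errors on either side of δ
0.5 * small^2             # 0.08, quadratic branch (|ŷ - y| <= δ)
δ * (large - 0.5 * δ)     # 2.5,  linear branch (otherwise)
```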
@@ -68,7 +69,7 @@ end
crossentropy(ŷ, y; weight = nothing)
Return the cross entropy between the given probability distributions;
computed as `-sum(y .* log.(ŷ) .* weight) / size(y, 2)`.
calculated as `-sum(y .* log.(ŷ) .* weight) / size(y, 2)`.
`weight` can be `Nothing`, a `Number` or an `AbstractVector`.
`weight=nothing` acts like `weight=1` but is faster.
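
A plain-Julia check of the stated `crossentropy` formula, with `weight` omitted and made-up distributions:

```julia
ŷ = [0.9 0.2; 0.1 0.8]            # each column is a predicted probability distribution
y = [1.0 0.0; 0.0 1.0]            # one-hot targets
-sum(y .* log.(ŷ)) / size(y, 2)   # ≈ 0.164
```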
@@ -87,7 +88,8 @@ crossentropy(ŷ::AbstractVecOrMat, y::AbstractVecOrMat; weight=nothing) = _cros
logitcrossentropy(ŷ, y; weight = 1)
Return the crossentropy computed after a [`Flux.logsoftmax`](@ref) operation;
computed as `-sum(y .* logsoftmax(ŷ) .* weight) / size(y, 2)`.
calculated as `-sum(y .* logsoftmax(ŷ) .* weight) / size(y, 2)`.
`logitcrossentropy(ŷ, y)` is mathematically equivalent to
[`Flux.crossentropy(softmax(log(ŷ)), y)`](@ref) but it is more numerically stable.
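
The same targets evaluated from raw logits, as a sketch of the `logsoftmax`-based formula above:

```julia
using Flux: logsoftmax

logits = [2.0 0.5; 1.0 3.0]       # unnormalised scores, one column per sample
y = [1.0 0.0; 0.0 1.0]
-sum(y .* logsoftmax(logits)) / size(y, 2)
```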
@@ -184,10 +185,14 @@ end
"""
kldivergence(ŷ, y)
KLDivergence is a measure of how much one probability distribution is different from the other.
It is always non-negative and zero only when both the distributions are equal everywhere.
Return the
[Kullback-Leibler divergence](https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence)
between the given probability distributions.
[KL Divergence](https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence).
KL divergence is a measure of how much one probability distribution is different
from the other.
It is always non-negative and zero only when both the distributions are equal
everywhere.
"""
function kldivergence(ŷ, y)
entropy = sum(y .* log.(y)) * 1 //size(y,2)
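
A plain-Julia sketch of the KL divergence for two made-up discrete distributions:

```julia
p = [0.1, 0.9]
q = [0.5, 0.5]
sum(p .* log.(p ./ q))   # ≥ 0, and exactly 0 when p == q everywhere
```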
@@ -198,20 +203,20 @@ end
"""
poisson(ŷ, y)
Poisson loss function is a measure of how the predicted distribution diverges from the expected distribution.
Returns `sum(ŷ .- y .* log.(ŷ)) / size(y, 2)`
Return how much the predicted distribution `ŷ` diverges from the expected Poisson
distribution `y`; calculated as `sum(ŷ .- y .* log.(ŷ)) / size(y, 2)`.
[Poisson Loss](https://peltarion.com/knowledge-center/documentation/modeling-view/build-an-ai-model/loss-functions/poisson).
[More information](https://peltarion.com/knowledge-center/documentation/modeling-view/build-an-ai-model/loss-functions/poisson).
"""
poisson(ŷ, y) = sum(ŷ .- y .* log.(ŷ)) * 1 // size(y,2)
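
A quick check of the Poisson loss formula with made-up rates and counts:

```julia
ŷ = [0.5 1.5; 1.0 2.0]   # predicted rates, one column per sample
y = [1.0 1.0; 1.0 2.0]   # observed counts
sum(ŷ .- y .* log.(ŷ)) / size(y, 2)
```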
"""
hinge(ŷ, y)
Measures the loss given the prediction `ŷ` and true labels `y` (containing 1 or -1).
Returns `sum((max.(0, 1 .- ŷ .* y))) / size(y, 2)`
Return the [hinge loss](https://en.wikipedia.org/wiki/Hinge_loss) given the
prediction `ŷ` and true labels `y` (containing 1 or -1); calculated as
`sum(max.(0, 1 .- ŷ .* y)) / size(y, 2)`.
[Hinge Loss](https://en.wikipedia.org/wiki/Hinge_loss)
See also: [`squared_hinge`](@ref)
"""
hinge(ŷ, y) = sum(max.(0, 1 .- ŷ .* y)) * 1 // size(y, 2)
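
A check of the hinge formula with ±1 labels; `squared_hinge` below simply squares each `max` term:

```julia
ŷ = [0.3 -1.2 2.0]                       # raw scores
y = [1.0 -1.0 1.0]                       # true labels in {-1, 1}
sum(max.(0, 1 .- ŷ .* y)) / size(y, 2)   # (0.7 + 0 + 0) / 3
```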
@@ -219,8 +224,8 @@ hinge(ŷ, y) = sum(max.(0, 1 .- ŷ .* y)) * 1 // size(y, 2)
"""
squared_hinge(ŷ, y)
Computes squared hinge loss given the prediction `ŷ` and true labels `y` (conatining 1 or -1).
Returns `sum((max.(0, 1 .- ŷ .* y)).^2) / size(y, 2)`
Return the squared hinge loss given the prediction `ŷ` and true labels `y`
(containing 1 or -1); calculated as `sum((max.(0, 1 .- ŷ .* y)).^2) / size(y, 2)`.
See also: [`hinge`](@ref)
"""
@@ -229,28 +234,29 @@ squared_hinge(ŷ, y) = sum((max.(0, 1 .- ŷ .* y)).^2) * 1 // size(y, 2)
"""
dice_coeff_loss(ŷ, y; smooth=1)
Loss function used in Image Segmentation. Calculates loss based on dice coefficient. Similar to F1_score.
Returns `1 - 2*sum(|ŷ .* y| + smooth) / (sum(ŷ.^2) + sum(y.^2) + smooth)`
[V-Net: Fully Convolutional Neural Networks forVolumetric Medical Image Segmentation](https://arxiv.org/pdf/1606.04797v1.pdf)
Return a loss based on the dice coefficient.
Used in the [V-Net](https://arxiv.org/pdf/1606.04797v1.pdf) image segmentation
architecture.
Similar to the F1_score. Calculated as:
1 - 2*sum(|ŷ .* y| + smooth) / (sum(ŷ.^2) + sum(y.^2) + smooth)
"""
dice_coeff_loss(ŷ, y; smooth=eltype(ŷ)(1.0)) = 1 - (2*sum(y .* ŷ) + smooth) / (sum(y.^2) + sum(ŷ.^2) + smooth)
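
A quick check of the dice-coefficient loss with `smooth = 1` and a made-up mask:

```julia
ŷ = [0.9, 0.1, 0.8]      # predicted mask
y = [1.0, 0.0, 1.0]      # ground-truth mask
smooth = 1.0
1 - (2 * sum(y .* ŷ) + smooth) / (sum(y .^ 2) + sum(ŷ .^ 2) + smooth)
```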
"""
tversky_loss(ŷ, y; β=0.7)
Used with imbalanced data to give more weightage to False negatives.
Return the [Tversky loss](https://arxiv.org/pdf/1706.05721.pdf).
Used with imbalanced data to give more weight to false negatives.
A larger β weights recall higher than precision (by placing more emphasis on false negatives).
Returns `1 - sum(|y .* ŷ| + 1) / (sum(y .* ŷ + β*(1 .- y) .* ŷ + (1 - β)*y .* (1 .- ŷ)) + 1)`
[Tversky loss function for image segmentation using 3D fully convolutional deep networks](https://arxiv.org/pdf/1706.05721.pdf)
Calculated as:
1 - sum(|y .* ŷ| + 1) / (sum(y .* ŷ + β*(1 .- y) .* ŷ + (1 - β)*y .* (1 .- ŷ)) + 1)
"""
tversky_loss(ŷ, y; β=eltype(ŷ)(0.7)) = 1 - (sum(y .* ŷ) + 1) / (sum(y .* ŷ + β*(1 .- y) .* ŷ + (1 - β)*y .* (1 .- ŷ)) + 1)
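
The Tversky variant on the same kind of made-up mask, with the default β = 0.7:

```julia
ŷ = [0.9, 0.1, 0.8]
y = [1.0, 0.0, 1.0]
β = 0.7
1 - (sum(y .* ŷ) + 1) / (sum(y .* ŷ + β*(1 .- y) .* ŷ + (1 - β)*y .* (1 .- ŷ)) + 1)
```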
"""
flatten(x::AbstractArray)
Transforms (w,h,c,b)-shaped input into (w x h x c,b)-shaped output,
Transform (w, h, c, b)-shaped input into (w × h × c, b)-shaped output
by linearizing all values for each element in the batch.
"""
function flatten(x::AbstractArray)
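
A sketch of the reshape that `flatten` performs, with hypothetical sizes:

```julia
x = rand(Float32, 4, 4, 3, 2)      # (w, h, c, b)
size(reshape(x, :, size(x, 4)))    # (48, 2), i.e. (w × h × c, b)
```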

@@ -45,8 +45,8 @@ cudaconvert(x::OneHotMatrix{<:CuArray}) = OneHotMatrix(x.height, cudaconvert(x.d
"""
onehot(l, labels[, unk])
Create a [`OneHotVector`](@ref) with its `l`-th element `true` based on
possible `labels` set.
Create a [`OneHotVector`](@ref) with its `l`-th element `true` based on the
possible set of `labels`.
If `unk` is given, return `onehot(unk, labels)` if the input label `l` is not found
in `labels`; otherwise it will error.
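
A minimal sketch of the two call forms, assuming `Flux.onehot`:

```julia
using Flux: onehot

onehot(:b, [:a, :b, :c])       # OneHotVector with its 2nd element true
onehot(:d, [:a, :b, :c], :a)   # :d is not a label, so this falls back to onehot(:a, [:a, :b, :c])
```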
@@ -80,8 +80,10 @@ end
"""
onehotbatch(ls, labels[, unk...])
Create an [`OneHotMatrix`](@ref) with a batch of labels based on possible `labels` set, returns the
`onehot(unk, labels)` if given labels `ls` is not found in set `labels`.
Create a [`OneHotMatrix`](@ref) with a batch of labels based on the
possible set of `labels`.
If `unk` is given, return [`onehot(unk, labels)`](@ref) if one of the input
labels `ls` is not found in `labels`; otherwise it will error.
# Examples
```jldoctest

@@ -2,6 +2,16 @@ using Juno
import Zygote: Params, gradient
"""
update!(x, x̄)
Update the array `x` according to `x .-= x̄`.
"""
function update!(x::AbstractArray, x̄)
x .-= x̄
end
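
A sketch of the in-place update, with a hypothetical gradient `x̄` and the method reached as `Flux.Optimise.update!`:

```julia
using Flux

x = [1.0, 2.0, 3.0]
x̄ = [0.1, 0.1, 0.1]            # a gradient for x
Flux.Optimise.update!(x, x̄)    # x is now [0.9, 1.9, 2.9]
```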
"""
update!(opt, p, g)
update!(opt, ps::Params, gs)
@@ -10,15 +20,7 @@ Perform an update step of the parameters `ps` (or the single parameter `p`)
according to optimizer `opt` and the gradients `gs` (the gradient `g`).
As a result, the parameters are mutated and the optimizer's internal state may change.
update!(x, x̄)
Update the array `x` according to `x .-= x̄`.
"""
function update!(x::AbstractArray, x̄)
x .-= x̄
end
function update!(opt, x, x̄)
x .-= apply!(opt, x, x̄)
end
@@ -41,7 +43,7 @@ struct StopException <: Exception end
stop()
Call `Flux.stop()` in a callback to indicate when a callback condition is met.
This would trigger the train loop to stop and exit.
This will trigger the train loop to stop and exit.
# Examples
```julia
@@ -57,19 +59,19 @@ end
"""
train!(loss, params, data, opt; cb)
For each datapoint `d` in `data` computes the gradient of `loss(d...)` through
backpropagation and calls the optimizer `opt`.
For each datapoint `d` in `data` compute the gradient of `loss(d...)` through
backpropagation and call the optimizer `opt`.
In case datapoints `d` are of numeric array type, assumes no splatting is needed
and computes the gradient of `loss(d)`.
In case datapoints `d` are of numeric array type, assume no splatting is needed
and compute the gradient of `loss(d)`.
Takes a callback as keyword argument `cb`. For example, this will print "training"
every 10 seconds (using [`throttle`](@ref)):
A callback is given with the keyword argument `cb`. For example, this will print
"training" every 10 seconds (using [`Flux.throttle`](@ref)):
train!(loss, params, data, opt,
cb = throttle(() -> println("training"), 10))
The callback can call [`Flux.stop()`](@ref) to interrupt the training loop.
The callback can call [`Flux.stop`](@ref) to interrupt the training loop.
Multiple optimisers and callbacks can be passed to `opt` and `cb` as arrays.
"""