Further docstring improvements in src/
Some had to be redone after the rebase
parent 64ce32ddcf
commit 2ce5f6d9bf
@@ -183,18 +183,11 @@ outdims(l::Diagonal, isize) = (length(l.α),)
 """
 Maxout(over)

-`Maxout` is a neural network layer which has a number of internal layers
-which all receive the same input. The layer returns the elementwise maximium
-of the internal layers' outputs.
+The [Maxout](https://arxiv.org/pdf/1302.4389.pdf) layer has a number of
+internal layers which all receive the same input. It returns the elementwise
+maximum of the internal layers' outputs.

 Maxout over linear dense layers satisfies the univeral approximation theorem.
-
-Reference:
-Ian J. Goodfellow, David Warde-Farley, Mehdi Mirza, Aaron Courville, and Yoshua Bengio.
-2013. Maxout networks.
-In Proceedings of the 30th International Conference on International Conference on Machine Learning - Volume 28 (ICML'13),
-Sanjoy Dasgupta and David McAllester (Eds.), Vol. 28. JMLR.org III-1319-III-1327.
-https://arxiv.org/pdf/1302.4389.pdf
 """
 struct Maxout{FS<:Tuple}
 over::FS
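To illustrate the `Maxout(over)` form documented above, a minimal sketch; the layer sizes are arbitrary and `over` is the tuple of internal layers held by the struct:

```julia
using Flux

# Two internal Dense layers receive the same input; Maxout returns their
# elementwise maximum.
over = (Dense(10, 5), Dense(10, 5))
m = Maxout(over)

x = rand(Float32, 10)
y = m(x)   # 5-element output: max.(over[1](x), over[2](x))
```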
@@ -65,9 +65,10 @@ end
 """
 AlphaDropout(p)

-A dropout layer. It is used in
+A dropout layer. Used in
 [Self-Normalizing Neural Networks](https://papers.nips.cc/paper/6698-self-normalizing-neural-networks.pdf).
-The AlphaDropout layer ensures that mean and variance of activations remains the same as before.
+The AlphaDropout layer ensures that mean and variance of activations
+remain the same as before.

 Does nothing to the input once [`testmode!`](@ref) is true.
 """
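A rough usage sketch for the layer documented above; pairing it with `selu` follows the Self-Normalizing Neural Networks paper, and the dropout probability here is arbitrary:

```julia
using Flux

m = Chain(Dense(10, 10, selu), AlphaDropout(0.2), Dense(10, 2))

x = rand(Float32, 10)
m(x)               # dropout is active by default (training mode)
Flux.testmode!(m)  # as documented, makes the layer do nothing to its input
m(x)
```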
@@ -123,8 +124,8 @@ end
 initβ = zeros, initγ = ones,
 ϵ = 1e-8, momentum = .1)

-Batch Normalization layer. The `channels` input should be the size of the
-channel dimension in your data (see below).
+[Batch Normalization](https://arxiv.org/pdf/1502.03167.pdf) layer.
+`channels` should be the size of the channel dimension in your data (see below).

 Given an array with `N` dimensions, call the `N-1`th the channel dimension. (For
 a batch of feature vectors this is just the data dimension, for `WHCN` images
@@ -136,9 +137,6 @@ per-channel `bias` and `scale` parameters).

 Use [`testmode!`](@ref) during inference.

-See [Batch Normalization: Accelerating Deep Network Training by Reducing
-Internal Covariate Shift](https://arxiv.org/pdf/1502.03167.pdf).
-
 # Examples
 ```julia
 m = Chain(
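For context, a sketch in the spirit of the (truncated) `# Examples` block above; feature vectors are `(features, batch)`, so the channel dimension is the first one and its size is passed to `BatchNorm`:

```julia
using Flux

m = Chain(
  Dense(28^2, 64),
  BatchNorm(64, relu),   # 64 = size of the channel (feature) dimension
  Dense(64, 10),
  BatchNorm(10),
  softmax)

x = rand(Float32, 28^2, 16)  # a batch of 16 flattened 28×28 inputs
y = m(x)                     # 10×16 output
```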
@@ -213,37 +211,6 @@ function Base.show(io::IO, l::BatchNorm)
 print(io, ")")
 end

-"""
-InstanceNorm(channels::Integer, σ = identity;
-initβ = zeros, initγ = ones,
-ϵ = 1e-8, momentum = .1)
-
-Instance Normalization layer. The `channels` input should be the size of the
-channel dimension in your data (see below).
-
-Given an array with `N` dimensions, call the `N-1`th the channel dimension. (For
-a batch of feature vectors this is just the data dimension, for `WHCN` images
-it's the usual channel dimension.)
-
-`InstanceNorm` computes the mean and variance for each each `W×H×1×1` slice and
-shifts them to have a new mean and variance (corresponding to the learnable,
-per-channel `bias` and `scale` parameters).
-
-Use [`testmode!`](@ref) during inference.
-
-See [Instance Normalization: The Missing Ingredient for Fast Stylization](https://arxiv.org/abs/1607.08022).
-
-# Examples
-```julia
-m = Chain(
-Dense(28^2, 64),
-InstanceNorm(64, relu),
-Dense(64, 10),
-InstanceNorm(10),
-softmax)
-```
-"""
-
 expand_inst = (x, as) -> reshape(repeat(x, outer=[1, as[length(as)]]), as...)

 mutable struct InstanceNorm{F,V,W,N}
@@ -258,6 +225,34 @@ mutable struct InstanceNorm{F,V,W,N}
 end

 # TODO: deprecate in v0.11
+"""
+InstanceNorm(channels::Integer, σ = identity;
+initβ = zeros, initγ = ones,
+ϵ = 1e-8, momentum = .1)
+
+[Instance Normalization](https://arxiv.org/abs/1607.08022) layer.
+`channels` should be the size of the channel dimension in your data (see below).
+
+Given an array with `N` dimensions, call the `N-1`th the channel dimension. (For
+a batch of feature vectors this is just the data dimension, for `WHCN` images
+it's the usual channel dimension.)
+
+`InstanceNorm` computes the mean and variance for each each `W×H×1×1` slice and
+shifts them to have a new mean and variance (corresponding to the learnable,
+per-channel `bias` and `scale` parameters).
+
+Use [`testmode!`](@ref) during inference.
+
+# Examples
+```julia
+m = Chain(
+Dense(28^2, 64),
+InstanceNorm(64, relu),
+Dense(64, 10),
+InstanceNorm(10),
+softmax)
+```
+"""
 InstanceNorm(λ, β, γ, μ, σ², ϵ, momentum) = InstanceNorm(λ, β, γ, μ, σ², ϵ, momentum, nothing)

 InstanceNorm(chs::Integer, λ = identity;
@@ -2,7 +2,8 @@
 """
 mae(ŷ, y)

-Return the mean of absolute error `sum(abs.(ŷ .- y)) / length(y)`
+Return the mean of absolute error; calculated as
+`sum(abs.(ŷ .- y)) / length(y)`.
 """
 mae(ŷ, y) = sum(abs.(ŷ .- y)) * 1 // length(y)

@@ -10,8 +11,8 @@ mae(ŷ, y) = sum(abs.(ŷ .- y)) * 1 // length(y)
 """
 mse(ŷ, y)

-Return the mean squared error between ŷ and y;
-defined as ``\\frac{1}{n} \\sum_{i=1}^n (ŷ_i - y_i)^2``.
+Return the mean squared error between ŷ and y; calculated as
+`sum((ŷ .- y).^2) / length(y)`.

 # Examples
 ```jldoctest
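A small numeric check of the two formulas above (values chosen arbitrarily):

```julia
using Flux

ŷ = [1.0, 2.0, 3.0]
y = [2.0, 2.0, 5.0]

Flux.mae(ŷ, y)  # (1 + 0 + 2) / 3 = 1.0
Flux.mse(ŷ, y)  # (1 + 0 + 4) / 3 ≈ 1.67
```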
@@ -25,10 +26,11 @@ mse(ŷ, y) = sum((ŷ .- y).^2) * 1 // length(y)
 """
 msle(ŷ, y; ϵ=eps(eltype(ŷ)))

-Returns the mean of the squared logarithmic errors `sum((log.(ŷ .+ ϵ) .- log.(y .+ ϵ)).^2) / length(y)`.
+Return the mean of the squared logarithmic errors; calculated as
+`sum((log.(ŷ .+ ϵ) .- log.(y .+ ϵ)).^2) / length(y)`.
 The `ϵ` term provides numerical stability.

-This error penalizes an under-predicted estimate greater than an over-predicted estimate.
+Penalizes an under-predicted estimate greater than an over-predicted estimate.
 """
 msle(ŷ, y; ϵ=eps(eltype(ŷ))) = sum((log.(ŷ .+ ϵ) .- log.(y .+ ϵ)).^2) * 1 // length(y)

@@ -37,13 +39,12 @@ msle(ŷ, y; ϵ=eps(eltype(ŷ))) = sum((log.(ŷ .+ ϵ) .- log.(y .+ ϵ)).^2) *
 """
 huber_loss(ŷ, y; δ=1.0)

-Computes the mean of the Huber loss given the prediction `ŷ` and true values `y`. By default, δ is set to 1.0.
+Return the mean of the [Huber loss](https://en.wikipedia.org/wiki/Huber_loss)
+given the prediction `ŷ` and true values `y`.

-| 0.5*|ŷ - y|, for |ŷ - y| <= δ
-Hubber loss = |
-| δ*(|ŷ - y| - 0.5*δ), otherwise
-
-[`Huber Loss`](https://en.wikipedia.org/wiki/Huber_loss).
+| 0.5 * |ŷ - y|, for |ŷ - y| <= δ
+Huber loss = |
+| δ * (|ŷ - y| - 0.5 * δ), otherwise
 """
 function huber_loss(ŷ, y; δ=eltype(ŷ)(1))
 abs_error = abs.(ŷ .- y)
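A usage sketch for the documented `δ` keyword (the inputs are arbitrary):

```julia
using Flux

ŷ = [1.2, 0.4, 3.0]
y = [1.0, 1.0, 1.0]

Flux.huber_loss(ŷ, y)           # default δ = 1: quadratic for |ŷ - y| <= δ
Flux.huber_loss(ŷ, y; δ = 0.5)  # a smaller δ switches to the linear branch sooner
```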
@@ -68,7 +69,7 @@ end
 crossentropy(ŷ, y; weight = nothing)

 Return the cross entropy between the given probability distributions;
-computed as `-sum(y .* log.(ŷ) .* weight) / size(y, 2)`.
+calculated as `-sum(y .* log.(ŷ) .* weight) / size(y, 2)`.

 `weight` can be `Nothing`, a `Number` or an `AbstractVector`.
 `weight=nothing` acts like `weight=1` but is faster.
@@ -87,7 +88,7 @@ crossentropy(ŷ::AbstractVecOrMat, y::AbstractVecOrMat; weight=nothing) = _cros
 logitcrossentropy(ŷ, y; weight = 1)

 Return the crossentropy computed after a [`Flux.logsoftmax`](@ref) operation;
-computed as `-sum(y .* logsoftmax(ŷ) .* weight) / size(y, 2)`.
+calculated as `-sum(y .* logsoftmax(ŷ) .* weight) / size(y, 2)`.

 `logitcrossentropy(ŷ, y)` is mathematically equivalent to
 [`Flux.crossentropy(softmax(log(ŷ)), y)`](@ref) but it is more numerically stable.
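A sketch of the relationship spelled out above, with arbitrary logits and one-hot targets:

```julia
using Flux

y = Flux.onehotbatch([1, 3], 1:3)   # 3×2 one-hot targets
logits = randn(Float32, 3, 2)       # raw, unnormalised scores

Flux.crossentropy(softmax(logits), y)  # expects probabilities
Flux.logitcrossentropy(logits, y)      # same quantity, computed stably from logits
```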
@@ -184,10 +185,14 @@ end
 """
 kldivergence(ŷ, y)

-KLDivergence is a measure of how much one probability distribution is different from the other.
-It is always non-negative and zero only when both the distributions are equal everywhere.
+Return the
+[Kullback-Leibler divergence](https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence)
+between the given probability distributions.

-[KL Divergence](https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence).
+KL divergence is a measure of how much one probability distribution is different
+from the other.
+It is always non-negative and zero only when both the distributions are equal
+everywhere.
 """
 function kldivergence(ŷ, y)
 entropy = sum(y .* log.(y)) * 1 //size(y,2)
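A small sketch of the properties stated above (columns are treated as distributions, matching the `size(y, 2)` normalisation used throughout this file):

```julia
using Flux

p = [0.1 0.8; 0.9 0.2]   # each column is a probability distribution
q = [0.2 0.7; 0.8 0.3]

Flux.kldivergence(q, p)  # non-negative
Flux.kldivergence(p, p)  # ≈ 0 when the distributions are equal
```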
@@ -198,20 +203,20 @@ end
 """
 poisson(ŷ, y)

-Poisson loss function is a measure of how the predicted distribution diverges from the expected distribution.
-Returns `sum(ŷ .- y .* log.(ŷ)) / size(y, 2)`
+Return how much the predicted distribution `ŷ` diverges from the expected Poisson
+distribution `y`; calculated as `sum(ŷ .- y .* log.(ŷ)) / size(y, 2)`.

-[Poisson Loss](https://peltarion.com/knowledge-center/documentation/modeling-view/build-an-ai-model/loss-functions/poisson).
+[More information.](https://peltarion.com/knowledge-center/documentation/modeling-view/build-an-ai-model/loss-functions/poisson).
 """
 poisson(ŷ, y) = sum(ŷ .- y .* log.(ŷ)) * 1 // size(y,2)

 """
 hinge(ŷ, y)

-Measures the loss given the prediction `ŷ` and true labels `y` (containing 1 or -1).
-Returns `sum((max.(0, 1 .- ŷ .* y))) / size(y, 2)`
+Return the [hinge loss](https://en.wikipedia.org/wiki/Hinge_loss) given the
+prediction `ŷ` and true labels `y` (containing 1 or -1); calculated as
+`sum(max.(0, 1 .- ŷ .* y)) / size(y, 2)`.

-[Hinge Loss](https://en.wikipedia.org/wiki/Hinge_loss)
 See also: [`squared_hinge`](@ref)
 """
 hinge(ŷ, y) = sum(max.(0, 1 .- ŷ .* y)) * 1 // size(y, 2)
@@ -219,8 +224,8 @@ hinge(ŷ, y) = sum(max.(0, 1 .- ŷ .* y)) * 1 // size(y, 2)
 """
 squared_hinge(ŷ, y)

-Computes squared hinge loss given the prediction `ŷ` and true labels `y` (conatining 1 or -1).
-Returns `sum((max.(0, 1 .- ŷ .* y)).^2) / size(y, 2)`
+Return the squared hinge loss given the prediction `ŷ` and true labels `y`
+(containing 1 or -1); calculated as `sum((max.(0, 1 .- ŷ .* y)).^2) / size(y, 2)`.

 See also: [`hinge`](@ref)
 """
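A usage sketch covering both hinge variants documented above (scores and labels are arbitrary; labels are in {-1, +1}):

```julia
using Flux

ŷ = [0.3 -1.2 2.0]   # raw scores, one column per sample
y = [1 -1 1]         # true labels in {-1, +1}

Flux.hinge(ŷ, y)          # sum(max.(0, 1 .- ŷ .* y)) / size(y, 2)
Flux.squared_hinge(ŷ, y)  # squares each hinge term before the same division
```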
@@ -229,28 +234,29 @@ squared_hinge(ŷ, y) = sum((max.(0, 1 .- ŷ .* y)).^2) * 1 // size(y, 2)
 """
 dice_coeff_loss(ŷ, y; smooth=1)

-Loss function used in Image Segmentation. Calculates loss based on dice coefficient. Similar to F1_score.
-Returns `1 - 2*sum(|ŷ .* y| + smooth) / (sum(ŷ.^2) + sum(y.^2) + smooth)`
-
-[V-Net: Fully Convolutional Neural Networks forVolumetric Medical Image Segmentation](https://arxiv.org/pdf/1606.04797v1.pdf)
+Return a loss based on the dice coefficient.
+Used in the [V-Net](https://arxiv.org/pdf/1606.04797v1.pdf) image segmentation
+architecture.
+Similar to the F1_score. Calculated as:
+1 - 2*sum(|ŷ .* y| + smooth) / (sum(ŷ.^2) + sum(y.^2) + smooth)`
 """
 dice_coeff_loss(ŷ, y; smooth=eltype(ŷ)(1.0)) = 1 - (2*sum(y .* ŷ) + smooth) / (sum(y.^2) + sum(ŷ.^2) + smooth)

 """
 tversky_loss(ŷ, y; β=0.7)

-Used with imbalanced data to give more weightage to False negatives.
+Return the [Tversky loss](https://arxiv.org/pdf/1706.05721.pdf).
+Used with imbalanced data to give more weight to false negatives.
 Larger β weigh recall higher than precision (by placing more emphasis on false negatives)
-Returns `1 - sum(|y .* ŷ| + 1) / (sum(y .* ŷ + β*(1 .- y) .* ŷ + (1 - β)*y .* (1 .- ŷ)) + 1)`
-
-[Tversky loss function for image segmentation using 3D fully convolutional deep networks](https://arxiv.org/pdf/1706.05721.pdf)
+Calculated as:
+1 - sum(|y .* ŷ| + 1) / (sum(y .* ŷ + β*(1 .- y) .* ŷ + (1 - β)*y .* (1 .- ŷ)) + 1)
 """
 tversky_loss(ŷ, y; β=eltype(ŷ)(0.7)) = 1 - (sum(y .* ŷ) + 1) / (sum(y .* ŷ + β*(1 .- y) .* ŷ + (1 - β)*y .* (1 .- ŷ)) + 1)

 """
 flatten(x::AbstractArray)

-Transforms (w,h,c,b)-shaped input into (w x h x c,b)-shaped output,
+Transform (w, h, c, b)-shaped input into (w × h × c, b)-shaped output
 by linearizing all values for each element in the batch.
 """
 function flatten(x::AbstractArray)
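A quick check of the documented shape transformation:

```julia
using Flux

x = rand(Float32, 28, 28, 3, 16)  # (w, h, c, b)
size(Flux.flatten(x))             # (28 * 28 * 3, 16) == (2352, 16)
```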
@@ -45,8 +45,8 @@ cudaconvert(x::OneHotMatrix{<:CuArray}) = OneHotMatrix(x.height, cudaconvert(x.d
 """
 onehot(l, labels[, unk])

-Create a [`OneHotVector`](@ref) with its `l`-th element `true` based on
-possible `labels` set.
+Create a [`OneHotVector`](@ref) with its `l`-th element `true` based on the
+possible set of `labels`.
 If `unk` is given, return `onehot(unk, labels)` if the input label `l` is not found
 in `labels`; otherwise it will error.

@@ -80,8 +80,10 @@ end
 """
 onehotbatch(ls, labels[, unk...])

-Create an [`OneHotMatrix`](@ref) with a batch of labels based on possible `labels` set, returns the
-`onehot(unk, labels)` if given labels `ls` is not found in set `labels`.
+Create a [`OneHotMatrix`](@ref) with a batch of labels based on the
+possible set of `labels`.
+If `unk` is given, return [`onehot(unk, labels)`](@ref) if one of the input
+labels `ls` is not found in `labels`; otherwise it will error.

 # Examples
 ```jldoctest
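The two signatures documented above in action (labels chosen arbitrarily):

```julia
using Flux: onehot, onehotbatch

onehot(:b, [:a, :b, :c])                 # OneHotVector, true at position 2
onehot(:d, [:a, :b, :c], :a)             # unknown label falls back to onehot(:a, [:a, :b, :c])
onehotbatch([:b, :a, :b], [:a, :b, :c])  # 3×3 OneHotMatrix, one column per label
```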
@@ -2,23 +2,25 @@ using Juno
 import Zygote: Params, gradient


+
 """
-update!(opt, p, g)
-update!(opt, ps::Params, gs)
-
-Perform an update step of the parameters `ps` (or the single parameter `p`)
-according to optimizer `opt` and the gradients `gs` (the gradient `g`).
-
-As a result, the parameters are mutated and the optimizer's internal state may change.
-
 update!(x, x̄)

 Update the array `x` according to `x .-= x̄`.
 """
 function update!(x::AbstractArray, x̄)
 x .-= x̄
 end

+"""
+update!(opt, p, g)
+update!(opt, ps::Params, gs)
+
+Perform an update step of the parameters `ps` (or the single parameter `p`)
+according to optimizer `opt` and the gradients `gs` (the gradient `g`).
+
+As a result, the parameters are mutated and the optimizer's internal state may change.
+"""
 function update!(opt, x, x̄)
 x .-= apply!(opt, x, x̄)
 end
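A sketch of both `update!` methods documented in the moved docstring, using `Descent` as an example optimiser and a hand-made gradient:

```julia
using Flux
using Flux.Optimise: Descent, update!

W = rand(2, 3)
gW = ones(2, 3)        # stand-in gradient for W
opt = Descent(0.1)

update!(opt, W, gW)    # for Descent this is W .-= 0.1 .* gW
update!(W, 0.1 .* gW)  # the plain-array method: W .-= x̄
```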
@@ -41,7 +43,7 @@ struct StopException <: Exception end
 stop()

 Call `Flux.stop()` in a callback to indicate when a callback condition is met.
-This would trigger the train loop to stop and exit.
+This will trigger the train loop to stop and exit.

 # Examples
 ```julia
@@ -57,19 +59,19 @@ end
 """
 train!(loss, params, data, opt; cb)

-For each datapoint `d` in `data` computes the gradient of `loss(d...)` through
-backpropagation and calls the optimizer `opt`.
+For each datapoint `d` in `data` compute the gradient of `loss(d...)` through
+backpropagation and call the optimizer `opt`.

-In case datapoints `d` are of numeric array type, assumes no splatting is needed
-and computes the gradient of `loss(d)`.
+In case datapoints `d` are of numeric array type, assume no splatting is needed
+and compute the gradient of `loss(d)`.

-Takes a callback as keyword argument `cb`. For example, this will print "training"
-every 10 seconds (using [`throttle`](@ref)):
+A callback is given with the keyword argument `cb`. For example, this will print
+"training" every 10 seconds (using [`Flux.throttle`](@ref)):

 train!(loss, params, data, opt,
 cb = throttle(() -> println("training"), 10))

-The callback can call [`Flux.stop()`](@ref) to interrupt the training loop.
+The callback can call [`Flux.stop`](@ref) to interrupt the training loop.

 Multiple optimisers and callbacks can be passed to `opt` and `cb` as arrays.
 """
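To tie the pieces together, a minimal training-loop sketch in the style of the Flux basics, using the `cb`/`throttle` pattern shown in the docstring (model and data are toy placeholders):

```julia
using Flux
using Flux: params, throttle

W, b = rand(2, 5), rand(2)
predict(x) = W * x .+ b
loss(x, y) = Flux.mse(predict(x), y)

x, y = rand(5), rand(2)
data = Iterators.repeated((x, y), 50)
opt = Descent(0.1)

Flux.train!(loss, params(W, b), data, opt,
            cb = throttle(() -> @show(loss(x, y)), 10))
```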