Improve docstrings
Improvements like...
- fixing typos,
- removing trailing and double whitespaces,
- using `jldoctest` blocks where applicable,
- fixing, updating or correctly setting up existing doctests,
- improving consistency (for example, always use "# Examples" instead of other variants),
- removing empty lines between docstrings and functions,
- instead of mentioning keywords, put them into the docstring,
- adding some missing but useful keywords,
- adding references (`@ref`),
- using LaTeX math where applicable, and
- linking papers.

Debatable stuff that is untouched:
- BE/AE s/z irregularities ("normalise" versus "normalize") since most papers use the AE version while the Flux source code was written with BE spelling.
- Names of normalization functions are capitalized ("Batch Normalization" instead of "batch normalization").
This commit is contained in:
parent c76b7315ac, commit ab86e350f2
@@ -33,9 +33,10 @@ const TESTLABELS = joinpath(dir, "t10k-labels-idx1-ubyte")
 Load the Fashion-MNIST images.
 
-Each image is a 28×28 array of `Gray` colour values (see Colors.jl).
+Each image is a 28×28 array of `Gray` colour values
+(see [Colors.jl](https://github.com/JuliaGraphics/Colors.jl)).
 
-Returns the 60,000 training images by default; pass `:test` to retreive the
+Return the 60,000 training images by default; pass `:test` to retrieve the
 10,000 test images.
 """
 function images(set = :train)
@@ -49,10 +50,10 @@ end
     labels()
     labels(:test)
 
-Load the labels corresponding to each of the images returned from `images()`.
+Load the labels corresponding to each of the images returned from [`images()`](@ref).
 Each label is a number from 0-9.
 
-Returns the 60,000 training labels by default; pass `:test` to retreive the
+Return the 60,000 training labels by default; pass `:test` to retrieve the
 10,000 test labels.
 """
 function labels(set = :train)
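For context, a minimal usage sketch of the `images`/`labels` pair documented above, assuming the `Flux.Data.FashionMNIST` module of the Flux version targeted by this commit (illustrative only, not part of the changed files):

```julia
using Flux

# Training set: 60,000 images (28×28 `Gray` arrays) and matching labels in 0:9.
train_imgs = Flux.Data.FashionMNIST.images()
train_lbls = Flux.Data.FashionMNIST.labels()

# Test set: 10,000 images and labels, selected with `:test`.
test_imgs = Flux.Data.FashionMNIST.images(:test)
test_lbls = Flux.Data.FashionMNIST.labels(:test)
```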
@@ -4,11 +4,10 @@ Fisher's classic iris dataset.
 Measurements from 3 different species of iris: setosa, versicolor and
 virginica. There are 50 examples of each species.
 
-There are 4 measurements for each example: sepal length, sepal width, petal
-length and petal width. The measurements are in centimeters.
+There are 4 measurements for each example: sepal length, sepal width,
+petal length and petal width. The measurements are in centimeters.
 
 The module retrieves the data from the [UCI Machine Learning Repository](https://archive.ics.uci.edu/ml/datasets/iris).
-
 """
 module Iris
@@ -33,9 +32,7 @@ end
 Get the labels of the iris dataset, a 150 element array of strings listing the
 species of each example.
 
-```jldoctest
-julia> using Flux
-
+```jldoctest; setup = :(Flux.Data.Iris.load())
 julia> labels = Flux.Data.Iris.labels();
 
 julia> summary(labels)
@@ -58,9 +55,7 @@ Get the features of the iris dataset. This is a 4x150 matrix of Float64
 elements. It has a row for each feature (sepal length, sepal width,
 petal length, petal width) and a column for each example.
 
-```jldoctest
-julia> using Flux
-
+```jldoctest; setup = :(Flux.Data.Iris.load())
 julia> features = Flux.Data.Iris.features();
 
 julia> summary(features)
@@ -83,9 +83,10 @@ getfeatures(io::IO, index::Integer) = vec(getimage(io, index))
 Load the MNIST images.
 
-Each image is a 28×28 array of `Gray` colour values (see Colors.jl).
+Each image is a 28×28 array of `Gray` colour values
+(see [Colors.jl](https://github.com/JuliaGraphics/Colors.jl)).
 
-Returns the 60,000 training images by default; pass `:test` to retreive the
+Return the 60,000 training images by default; pass `:test` to retrieve the
 10,000 test images.
 """
 function images(set = :train)
@@ -99,10 +100,10 @@ end
     labels()
     labels(:test)
 
-Load the labels corresponding to each of the images returned from `images()`.
+Load the labels corresponding to each of the images returned from [`images()`](@ref).
 Each label is a number from 0-9.
 
-Returns the 60,000 training labels by default; pass `:test` to retreive the
+Return the 60,000 training labels by default; pass `:test` to retrieve the
 10,000 test labels.
 """
 function labels(set = :train)
@@ -4,17 +4,23 @@
 Chain multiple layers / functions together, so that they are called in sequence
 on a given input.
 
-```julia
-m = Chain(x -> x^2, x -> x+1)
-m(5) == 26
-
-m = Chain(Dense(10, 5), Dense(5, 2))
-x = rand(10)
-m(x) == m[2](m[1](x))
-```
-
 `Chain` also supports indexing and slicing, e.g. `m[2]` or `m[1:end-1]`.
 `m[1:3](x)` will calculate the output of the first three layers.
+
+# Examples
+```jldoctest
+julia> m = Chain(x -> x^2, x -> x+1);
+
+julia> m(5) == 26
+true
+
+julia> m = Chain(Dense(10, 5), Dense(5, 2));
+
+julia> x = rand(10);
+
+julia> m(x) == m[2](m[1](x))
+true
+```
 """
 struct Chain{T<:Tuple}
   layers::T
@@ -60,6 +66,7 @@ outdims(c::Chain, isize) = foldl(∘, map(l -> (x -> outdims(l, x)), c.layers))(
 # only slightly changed to better handle interaction with Zygote @dsweber2
 """
     activations(c::Chain, input)
+
 Calculate the forward results of each layers in Chain `c` with `input` as model input.
 """
 function activations(c::Chain, input)
@@ -78,14 +85,15 @@ extraChain(::Tuple{}, x) = ()
 """
     Dense(in::Integer, out::Integer, σ = identity)
 
-Creates a traditional `Dense` layer with parameters `W` and `b`.
+Create a traditional `Dense` layer with parameters `W` and `b`.
 
     y = σ.(W * x .+ b)
 
 The input `x` must be a vector of length `in`, or a batch of vectors represented
 as an `in × N` matrix. The out `y` will be a vector or batch of length `out`.
 
-```julia
+# Examples
+```jldoctest; setup = :(using Random; Random.seed!(0))
 julia> d = Dense(5, 2)
 Dense(5, 2)
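A short sketch of the `Dense` semantics described in the docstring above, for a single input and for a batch (`W` and `b` are the layer's parameters; standard Flux API assumed):

```julia
using Flux

d = Dense(5, 2, relu)        # weight W is 2×5, bias b has length 2
x = rand(Float32, 5)         # one input vector of length `in`
y = d(x)                     # length-2 output: relu.(d.W * x .+ d.b)

X = rand(Float32, 5, 10)     # a batch of 10 inputs as an `in × N` matrix
Y = d(X)                     # 2×10 output, one column per example
```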
@@ -145,7 +153,7 @@ outdims(l::Dense, isize) = (size(l.W)[1],)
 """
     Diagonal(in::Integer)
 
-Creates an element-wise linear transformation layer with learnable
+Create an element-wise linear transformation layer with learnable
 vectors `α` and `β`:
 
     y = α .* x .+ β
@@ -176,8 +184,8 @@ outdims(l::Diagonal, isize) = (length(l.α),)
 """
     Maxout(over)
 
-`Maxout` is a neural network layer, which has a number of internal layers,
-which all have the same input, and the maxout returns the elementwise maximium
+`Maxout` is a neural network layer which has a number of internal layers
+which all receive the same input. The layer returns the elementwise maximium
 of the internal layers' outputs.
 
 Maxout over linear dense layers satisfies the univeral approximation theorem.
@@ -196,13 +204,14 @@ end
 """
     Maxout(f, n_alts)
 
-Constructs a Maxout layer over `n_alts` instances of the layer given by `f`.
-The function takes no arguement and should return some callable layer.
-Conventionally this is a linear dense layer.
+Construct a Maxout layer over `n_alts` instances of the layer given by `f`.
+The function takes no arguments and should return some callable layer.
+Conventionally, this is a linear dense layer.
 
-For example the following example which
-will construct a `Maxout` layer over 4 internal dense linear layers,
-each identical in structure (784 inputs, 128 outputs).
+# Examples
+
+This constructs a `Maxout` layer over 4 internal dense linear layers, each
+identical in structure (784 inputs, 128 outputs):
 ```julia
 insize = 784
 outsize = 128
@@ -223,16 +232,18 @@ end
 outdims(l::Maxout, isize) = outdims(first(l.over), isize)
 
 """
-    SkipConnection(layers, connection)
+    SkipConnection(layer, connection)
 
-Creates a Skip Connection, of a layer or `Chain` of consecutive layers
-plus a shortcut connection. The connection function will combine the result of the layers
-with the original input, to give the final output.
+Create a skip connection which consists of a layer or `Chain` of consecutive
+layers and a shortcut connection linking the block's input to the output
+through a user-supplied 2-argument callable. The first argument to the callable
+will be propagated through the given `layer` while the second is the unchanged,
+"skipped" input.
 
-The simplest 'ResNet'-type connection is just `SkipConnection(layer, +)`,
+The simplest "ResNet"-type connection is just `SkipConnection(layer, +)`,
 and requires the output of the layers to be the same shape as the input.
 Here is a more complicated example:
-```
+```julia
 m = Conv((3,3), 4=>7, pad=(1,1))
 x = ones(5,5,4,10);
 size(m(x)) == (5, 5, 7, 10)
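Continuing the example quoted in the hunk above, a hedged sketch of how the 2-argument `connection` combines the layer output with the skipped input (channel-wise concatenation is one common choice; standard Flux API assumed):

```julia
using Flux

m  = Conv((3, 3), 4 => 7, pad = (1, 1))
sm = SkipConnection(m, (mx, x) -> cat(mx, x, dims = 3))  # concatenate along the channel dim

x = ones(Float32, 5, 5, 4, 10)
size(sm(x))  # (5, 5, 11, 10): 7 channels from `m` plus the 4 skipped input channels

# The simplest "ResNet"-style block simply adds the two branches:
res = SkipConnection(Dense(10, 10), +)
```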
@@ -8,25 +8,26 @@ _convtransoutdims(isize, ksize, ssize, dsize, pad) = (isize .- 1).*ssize .+ 1 .+
 expand(N, i::Tuple) = i
 expand(N, i::Integer) = ntuple(_ -> i, N)
 """
-    Conv(size, in=>out)
-    Conv(size, in=>out, relu)
+    Conv(size, in => out, σ = identity; init = glorot_uniform,
+         stride = 1, pad = 0, dilation = 1)
 
 Standard convolutional layer. `size` should be a tuple like `(2, 2)`.
 `in` and `out` specify the number of input and output channels respectively.
 
-Example: Applying Conv layer to a 1-channel input using a 2x2 window size,
-giving us a 16-channel output. Output is activated with ReLU.
-
-    size = (2,2)
-    in = 1
-    out = 16
-    Conv((2, 2), 1=>16, relu)
-
 Data should be stored in WHCN order (width, height, # channels, batch size).
 In other words, a 100×100 RGB image would be a `100×100×3×1` array,
 and a batch of 50 would be a `100×100×3×50` array.
 
-Takes the keyword arguments `pad`, `stride` and `dilation`.
+# Examples
+
+Apply a `Conv` layer to a 1-channel input using a 2×2 window size, giving us a
+16-channel output. Output is activated with ReLU.
+```julia
+size = (2,2)
+in = 1
+out = 16
+Conv(size, in => out, relu)
+```
 """
 struct Conv{N,M,F,A,V}
   σ::F
@@ -76,8 +77,8 @@ end
 """
     outdims(l::Conv, isize::Tuple)
 
-Calculate the output dimensions given the input dimensions, `isize`.
-Batch size and channel size are ignored as per `NNlib.jl`.
+Calculate the output dimensions given the input dimensions `isize`.
+Batch size and channel size are ignored as per [NNlib.jl](https://github.com/FluxML/NNlib.jl).
 
 ```julia
 m = Conv((3, 3), 3 => 16)
@@ -89,17 +90,15 @@ outdims(l::Conv, isize) =
 output_size(DenseConvDims(_paddims(isize, size(l.weight)), size(l.weight); stride = l.stride, padding = l.pad, dilation = l.dilation))
 
 """
-    ConvTranspose(size, in=>out)
-    ConvTranspose(size, in=>out, relu)
+    ConvTranspose(size, in => out, σ = identity; init = glorot_uniform,
+                  stride = 1, pad = 0, dilation = 1)
 
 Standard convolutional transpose layer. `size` should be a tuple like `(2, 2)`.
 `in` and `out` specify the number of input and output channels respectively.
 
-Data should be stored in WHCN order (width, height, # channels, # batches).
+Data should be stored in WHCN order (width, height, # channels, batch size).
 In other words, a 100×100 RGB image would be a `100×100×3×1` array,
 and a batch of 50 would be a `100×100×3×50` array.
-
-Takes the keyword arguments `pad`, `stride` and `dilation`.
 """
 struct ConvTranspose{N,M,F,A,V}
   σ::F
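Since the `ConvTranspose` docstring above carries no example, here is a hedged usage sketch (a learnable 2× upsampling; output sizes follow the `_convtransoutdims` formula quoted earlier, standard Flux API assumed):

```julia
using Flux

up = ConvTranspose((2, 2), 16 => 8, relu; stride = 2)
x  = rand(Float32, 14, 14, 16, 1)   # WHCN input
size(up(x))                         # (28, 28, 8, 1): (14 - 1) * 2 + 2 = 28 per spatial dim
```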
@@ -165,18 +164,16 @@ end
 outdims(l::ConvTranspose{N}, isize) where N = _convtransoutdims(isize[1:2], size(l.weight)[1:N], l.stride, l.dilation, l.pad)
 
 """
-    DepthwiseConv(size, in=>out)
-    DepthwiseConv(size, in=>out, relu)
+    DepthwiseConv(size, in => out, σ = identity; init = glorot_uniform,
+                  stride = 1, pad = 0, dilation = 1)
 
 Depthwise convolutional layer. `size` should be a tuple like `(2, 2)`.
 `in` and `out` specify the number of input and output channels respectively.
 Note that `out` must be an integer multiple of `in`.
 
-Data should be stored in WHCN order (width, height, # channels, # batches).
+Data should be stored in WHCN order (width, height, # channels, batch size).
 In other words, a 100×100 RGB image would be a `100×100×3×1` array,
 and a batch of 50 would be a `100×100×3×50` array.
-
-Takes the keyword arguments `pad`, `stride` and `dilation`.
 """
 struct DepthwiseConv{N,M,F,A,V}
   σ::F
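Likewise, a minimal sketch for `DepthwiseConv`, where `out` must be an integer multiple of `in` (here 3 filters per input channel; constructor signature as documented above, assumed):

```julia
using Flux

dw = DepthwiseConv((3, 3), 8 => 24, relu; pad = 1)  # 24 = 3 × 8 output channels
x  = rand(Float32, 32, 32, 8, 1)                    # WHCN input
size(dw(x))                                         # (32, 32, 24, 1); pad = 1 keeps 32×32
```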
@@ -233,25 +230,26 @@ outdims(l::DepthwiseConv, isize) =
 output_size(DepthwiseConvDims(_paddims(isize, (1, 1, size(l.weight)[end], 1)), size(l.weight); stride = l.stride, padding = l.pad, dilation = l.dilation))
 
 """
-    CrossCor(size, in=>out)
-    CrossCor(size, in=>out, relu)
+    CrossCor(size, in => out, σ = identity; init = glorot_uniform,
+             stride = 1, pad = 0, dilation = 1)
 
 Standard cross convolutional layer. `size` should be a tuple like `(2, 2)`.
 `in` and `out` specify the number of input and output channels respectively.
 
-Example: Applying CrossCor layer to a 1-channel input using a 2x2 window size,
-giving us a 16-channel output. Output is activated with ReLU.
+Data should be stored in WHCN order (width, height, # channels, batch size).
+In other words, a 100×100 RGB image would be a `100×100×3×1` array,
+and a batch of 50 would be a `100×100×3×50` array.
+
+# Examples
+
+Apply a `CrossCor` layer to a 1-channel input using a 2×2 window size, giving us a
+16-channel output. Output is activated with ReLU.
+```julia
 size = (2,2)
 in = 1
 out = 16
 CrossCor((2, 2), 1=>16, relu)
-
-Data should be stored in WHCN order (width, height, # channels, # batches).
-In other words, a 100×100 RGB image would be a `100×100×3×1` array,
-and a batch of 50 would be a `100×100×3×50` array.
-
-Takes the keyword arguments `pad`, `stride` and `dilation`.
+```
 """
 struct CrossCor{N,M,F,A,V}
   σ::F
@@ -357,11 +355,9 @@ function Base.show(io::IO, g::GlobalMeanPool)
 end
 
 """
-    MaxPool(k)
+    MaxPool(k; pad = 0, stride = k)
 
-Max pooling layer. `k` stands for the size of the window for each dimension of the input.
-
-Takes the keyword arguments `pad` and `stride`.
+Max pooling layer. `k` is the size of the window for each dimension of the input.
 """
 struct MaxPool{N,M}
   k::NTuple{N,Int}
@@ -388,11 +384,9 @@ end
 outdims(l::MaxPool{N}, isize) where N = output_size(PoolDims(_paddims(isize, (l.k..., 1, 1)), l.k; stride = l.stride, padding = l.pad))
 
 """
-    MeanPool(k)
+    MeanPool(k; pad = 0, stride = k)
 
-Mean pooling layer. `k` stands for the size of the window for each dimension of the input.
-
-Takes the keyword arguments `pad` and `stride`.
+Mean pooling layer. `k` is the size of the window for each dimension of the input.
 """
 struct MeanPool{N,M}
   k::NTuple{N,Int}
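A small sketch of the pooling layers documented above; the stride defaults to the window size `k`, so a 2×2 window halves each spatial dimension (standard Flux API assumed):

```julia
using Flux

x = rand(Float32, 28, 28, 16, 4)      # WHCN input
size(MaxPool((2, 2))(x))              # (14, 14, 16, 4)
size(MeanPool((2, 2); pad = 0)(x))    # (14, 14, 16, 4)
```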
@@ -10,14 +10,14 @@ _dropout_shape(s, dims) = tuple((i ∉ dims ? 1 : si for (i, si) ∈ enumerate(s
 _dropout_kernel(y::T, p, q) where {T} = y > p ? T(1 / q) : T(0)
 
 """
-    dropout(p, dims = :)
+    dropout(x, p; dims = :)
 
-Dropout function. For each input, either sets that input to `0` (with probability
-`p`) or scales it by `1/(1-p)`. The `dims` argument is to specify the unbroadcasted
-dimensions, i.e. `dims=1` does dropout along columns and `dims=2` along rows. This is
-used as a regularisation, i.e. it reduces overfitting during training.
+The dropout function. For each input, either sets that input to `0` (with probability
+`p`) or scales it by `1 / (1 - p)`. `dims` specifies the unbroadcasted dimensions,
+e.g. `dims=1` applies dropout along columns and `dims=2` along rows.
+This is used as a regularisation, i.e. it reduces overfitting during training.
 
-See also [`Dropout`](@ref).
+See also the [`Dropout`](@ref) layer.
 """
 dropout(x, p; dims = :) = x
@@ -32,7 +32,7 @@ end
 A Dropout layer. In the forward pass, applies the [`dropout`](@ref) function on the input.
 
-Does nothing to the input once [`testmode!`](@ref) is true.
+Does nothing to the input once [`Flux.testmode!`](@ref) is `true`.
 """
 mutable struct Dropout{F,D}
   p::F
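To illustrate the `Dropout`/`testmode!` interaction described above, a hedged sketch (standard Flux API assumed):

```julia
using Flux

m = Chain(Dense(10, 5, relu), Dropout(0.5), Dense(5, 2))
x = rand(Float32, 10)

m(x)                      # training mode: roughly half the hidden activations are zeroed
Flux.testmode!(m, true)   # switch dropout off for inference
m(x)                      # deterministic output; Dropout passes the input through unchanged
```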
@@ -65,8 +65,8 @@ end
 """
     AlphaDropout(p)
 
-A dropout layer. It is used in Self-Normalizing Neural Networks.
-(https://papers.nips.cc/paper/6698-self-normalizing-neural-networks.pdf)
+A dropout layer. It is used in
+[Self-Normalizing Neural Networks](https://papers.nips.cc/paper/6698-self-normalizing-neural-networks.pdf).
 The AlphaDropout layer ensures that mean and variance of activations remains the same as before.
 
 Does nothing to the input once [`testmode!`](@ref) is true.
@@ -100,8 +100,8 @@ testmode!(m::AlphaDropout, mode = true) =
     LayerNorm(h::Integer)
 
 A [normalisation layer](https://arxiv.org/pdf/1607.06450.pdf) designed to be
-used with recurrent hidden states of size `h`. Normalises the mean/stddev of
-each input before applying a per-neuron gain/bias.
+used with recurrent hidden states of size `h`. Normalises the mean and standard
+deviation of each input before applying a per-neuron gain/bias.
 """
 struct LayerNorm{T}
   diag::Diagonal{T}
@@ -139,7 +139,7 @@ Use [`testmode!`](@ref) during inference.
 See [Batch Normalization: Accelerating Deep Network Training by Reducing
 Internal Covariate Shift](https://arxiv.org/pdf/1502.03167.pdf).
 
-Example:
+# Examples
 ```julia
 m = Chain(
   Dense(28^2, 64),
@@ -234,7 +234,7 @@ Use [`testmode!`](@ref) during inference.
 See [Instance Normalization: The Missing Ingredient for Fast Stylization](https://arxiv.org/abs/1607.08022).
 
-Example:
+# Examples
 ```julia
 m = Chain(
   Dense(28^2, 64),
@@ -316,28 +316,27 @@ function Base.show(io::IO, l::InstanceNorm)
 end
 
 """
-Group Normalization.
-This layer can outperform Batch-Normalization and Instance-Normalization.
-
     GroupNorm(chs::Integer, G::Integer, λ = identity;
               initβ = (i) -> zeros(Float32, i), initγ = (i) -> ones(Float32, i),
               ϵ = 1f-5, momentum = 0.1f0)
 
-``chs`` is the number of channels, the channel dimension of your input.
-For an array of N dimensions, the (N-1)th index is the channel dimension.
+[Group Normalization](https://arxiv.org/pdf/1803.08494.pdf) layer.
+This layer can outperform Batch Normalization and Instance Normalization.
 
-``G`` is the number of groups along which the statistics would be computed.
+`chs` is the number of channels, the channel dimension of your input.
+For an array of N dimensions, the `N-1`th index is the channel dimension.
+
+`G` is the number of groups along which the statistics are computed.
 The number of channels must be an integer multiple of the number of groups.
 
 Use [`testmode!`](@ref) during inference.
 
-Example:
-```
+# Examples
+```julia
 m = Chain(Conv((3,3), 1=>32, leakyrelu;pad = 1),
-          GroupNorm(32,16)) # 32 channels, 16 groups (G = 16), thus 2 channels per group used
+          GroupNorm(32,16))
+          # 32 channels, 16 groups (G = 16), thus 2 channels per group used
 ```
-
-Link : https://arxiv.org/pdf/1803.08494.pdf
 """
 mutable struct GroupNorm{F,V,W,N,T}
   G::T # number of groups
@@ -12,7 +12,7 @@ in the background. `cell` should be a model of the form:
     h, y = cell(h, x...)
 
-For example, here's a recurrent network that keeps a running total of its inputs.
+For example, here's a recurrent network that keeps a running total of its inputs:
 
 ```julia
 accum(h, x) = (h+x, x)
@@ -135,8 +135,8 @@ Base.show(io::IO, l::LSTMCell) =
 """
     LSTM(in::Integer, out::Integer)
 
-Long Short Term Memory recurrent layer. Behaves like an RNN but generally
-exhibits a longer memory span over sequences.
+[Long Short Term Memory](https://www.researchgate.net/publication/13853244_Long_Short-term_Memory)
+recurrent layer. Behaves like an RNN but generally exhibits a longer memory span over sequences.
 
 See [this article](https://colah.github.io/posts/2015-08-Understanding-LSTMs/)
 for a good overview of the internals.
@@ -176,8 +176,8 @@ Base.show(io::IO, l::GRUCell) =
 """
     GRU(in::Integer, out::Integer)
 
-Gated Recurrent Unit layer. Behaves like an RNN but generally
-exhibits a longer memory span over sequences.
+[Gated Recurrent Unit](https://arxiv.org/abs/1406.1078) layer. Behaves like an
+RNN but generally exhibits a longer memory span over sequences.
 
 See [this article](https://colah.github.io/posts/2015-08-Understanding-LSTMs/)
 for a good overview of the internals.
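A brief sketch of how the recurrent layers above are driven over a sequence, carrying hidden state between calls (standard Flux API assumed; `reset!` clears the state):

```julia
using Flux

m  = Chain(LSTM(10, 32), Dense(32, 5))   # GRU(10, 32) would be used the same way
xs = [rand(Float32, 10) for _ in 1:20]   # a sequence of 20 inputs

ys = [m(x) for x in xs]   # hidden state is carried from step to step
Flux.reset!(m)            # reset the state before feeding the next sequence
```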
@@ -73,7 +73,7 @@ computed as `-sum(y .* log.(ŷ) .* weight) / size(y, 2)`.
 `weight` can be `Nothing`, a `Number` or an `AbstractVector`.
 `weight=nothing` acts like `weight=1` but is faster.
 
-See also [`logitcrossentropy`](@ref), [`binarycrossentropy`](@ref).
+See also: [`Flux.logitcrossentropy`](@ref), [`Flux.binarycrossentropy`](@ref), [`Flux.logitbinarycrossentropy`](@ref)
 
 # Examples
 ```jldoctest
@@ -86,13 +86,13 @@ crossentropy(ŷ::AbstractVecOrMat, y::AbstractVecOrMat; weight=nothing) = _cros
 """
     logitcrossentropy(ŷ, y; weight = 1)
 
-Return the crossentropy computed after a [`logsoftmax`](@ref) operation;
+Return the crossentropy computed after a [`Flux.logsoftmax`](@ref) operation;
 computed as `-sum(y .* logsoftmax(ŷ) .* weight) / size(y, 2)`.
 
 `logitcrossentropy(ŷ, y)` is mathematically equivalent to
-[`crossentropy(softmax(log(ŷ)), y)`](@ref) but it is more numerically stable.
+[`Flux.crossentropy(softmax(log(ŷ)), y)`](@ref) but it is more numerically stable.
 
-See also [`crossentropy`](@ref), [`binarycrossentropy`](@ref).
+See also: [`Flux.crossentropy`](@ref), [`Flux.binarycrossentropy`](@ref), [`Flux.logitbinarycrossentropy`](@ref)
 
 # Examples
 ```jldoctest
@@ -107,9 +107,20 @@ end
 """
     binarycrossentropy(ŷ, y; ϵ=eps(ŷ))
 
-Return `-y*log(ŷ + ϵ) - (1-y)*log(1-ŷ + ϵ)`. The ϵ term provides numerical stability.
+Return ``-y*\\log(ŷ + ϵ) - (1-y)*\\log(1-ŷ + ϵ)``. The `ϵ` term provides numerical stability.
+
+Typically, the prediction `ŷ` is given by the output of a [`sigmoid`](@ref) activation.
+
+See also: [`Flux.crossentropy`](@ref), [`Flux.logitcrossentropy`](@ref), [`Flux.logitbinarycrossentropy`](@ref)
+
+# Examples
+```jldoctest
+julia> Flux.binarycrossentropy.(σ.([-1.1491, 0.8619, 0.3127]), [1, 1, 0])
+3-element Array{Float64,1}:
+ 1.424397097347566
+ 0.35231664672364077
+ 0.8616703662235441
+```
 """
 binarycrossentropy(ŷ, y; ϵ=eps(ŷ)) = -y*log(ŷ + ϵ) - (1 - y)*log(1 - ŷ + ϵ)
@@ -119,10 +130,19 @@ CuArrays.@cufunc binarycrossentropy(ŷ, y; ϵ=eps(ŷ)) = -y*log(ŷ + ϵ) - (1
 """
     logitbinarycrossentropy(ŷ, y)
 
-`logitbinarycrossentropy(ŷ, y)` is mathematically equivalent to `binarycrossentropy(σ(ŷ), y)`
-but it is more numerically stable.
+`logitbinarycrossentropy(ŷ, y)` is mathematically equivalent to
+[`Flux.binarycrossentropy(σ(log(ŷ)), y)`](@ref) but it is more numerically stable.
 
-See also [`binarycrossentropy`](@ref), [`sigmoid`](@ref), [`logsigmoid`](@ref).
+See also: [`Flux.crossentropy`](@ref), [`Flux.logitcrossentropy`](@ref), [`Flux.binarycrossentropy`](@ref)
+
+# Examples
+```jldoctest
+julia> Flux.logitbinarycrossentropy.([-1.1491, 0.8619, 0.3127], [1, 1, 0])
+3-element Array{Float64,1}:
+ 1.4243970973475661
+ 0.35231664672364094
+ 0.8616703662235443
+```
 """
 logitbinarycrossentropy(ŷ, y) = (1 - y)*ŷ - logσ(ŷ)
@@ -132,22 +152,23 @@ CuArrays.@cufunc logitbinarycrossentropy(ŷ, y) = (1 - y)*ŷ - logσ(ŷ)
 """
     normalise(x; dims=1)
 
-Normalises `x` to mean 0 and standard deviation 1, across the dimensions given by `dims`. Defaults to normalising over columns.
+Normalise `x` to mean 0 and standard deviation 1 across the dimensions given by `dims`.
+Defaults to normalising over columns.
 
-```julia-repl
+```jldoctest
 julia> a = reshape(collect(1:9), 3, 3)
 3×3 Array{Int64,2}:
  1  4  7
 2  5  8
 3  6  9
 
-julia> normalise(a)
+julia> Flux.normalise(a)
 3×3 Array{Float64,2}:
 -1.22474  -1.22474  -1.22474
  0.0       0.0       0.0
 1.22474   1.22474   1.22474
 
-julia> normalise(a, dims=2)
+julia> Flux.normalise(a, dims=2)
 3×3 Array{Float64,2}:
 -1.22474  0.0  1.22474
 -1.22474  0.0  1.22474
@@ -191,7 +212,7 @@ Measures the loss given the prediction `ŷ` and true labels `y` (containing 1 o
 Returns `sum((max.(0, 1 .- ŷ .* y))) / size(y, 2)`
 
 [Hinge Loss](https://en.wikipedia.org/wiki/Hinge_loss)
-See also [`squared_hinge`](@ref).
+See also: [`squared_hinge`](@ref)
 """
 hinge(ŷ, y) = sum(max.(0, 1 .- ŷ .* y)) * 1 // size(y, 2)
@@ -201,7 +222,7 @@ hinge(ŷ, y) = sum(max.(0, 1 .- ŷ .* y)) * 1 // size(y, 2)
 Computes squared hinge loss given the prediction `ŷ` and true labels `y` (conatining 1 or -1).
 Returns `sum((max.(0, 1 .- ŷ .* y)).^2) / size(y, 2)`
 
-See also [`hinge`](@ref).
+See also: [`hinge`](@ref)
 """
 squared_hinge(ŷ, y) = sum((max.(0, 1 .- ŷ .* y)).^2) * 1 // size(y, 2)
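A quick numeric sketch of the two hinge losses above, with labels in {-1, 1} as documented; the values follow the stated formulas `sum(max.(0, 1 .- ŷ .* y)) / size(y, 2)` and its squared variant:

```julia
using Flux

ŷ = [0.3 -1.2 0.8]    # raw scores as a 1×3 row matrix, so size(y, 2) == 3
y = [1 -1 -1]         # true labels, containing 1 or -1

Flux.hinge(ŷ, y)          # (0.7 + 0.0 + 1.8) / 3 ≈ 0.8333
Flux.squared_hinge(ŷ, y)  # (0.7^2 + 0.0^2 + 1.8^2) / 3 ≈ 1.2433
```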
@@ -45,22 +45,20 @@ cudaconvert(x::OneHotMatrix{<:CuArray}) = OneHotMatrix(x.height, cudaconvert(x.d
 """
     onehot(l, labels[, unk])
 
-Create an [`OneHotVector`](@ref) wtih `l`-th element be `true` based on possible `labels` set.
-If `unk` is given, it retruns `onehot(unk, labels)` if the input label `l` is not find in `labels`; otherwise
-it will error.
-
-## Examples
+Create a [`OneHotVector`](@ref) with its `l`-th element `true` based on
+possible `labels` set.
+If `unk` is given, return `onehot(unk, labels)` if the input label `l` is not found
+in `labels`; otherwise it will error.
+
+# Examples
 ```jldoctest
-julia> using Flux: onehot
-
-julia> onehot(:b, [:a, :b, :c])
+julia> Flux.onehot(:b, [:a, :b, :c])
 3-element Flux.OneHotVector:
  0
  1
  0
 
-julia> onehot(:c, [:a, :b, :c])
+julia> Flux.onehot(:c, [:a, :b, :c])
 3-element Flux.OneHotVector:
  0
  0
@@ -85,12 +83,9 @@ end
 Create an [`OneHotMatrix`](@ref) with a batch of labels based on possible `labels` set, returns the
 `onehot(unk, labels)` if given labels `ls` is not found in set `labels`.
 
-## Examples
-
+# Examples
 ```jldoctest
-julia> using Flux: onehotbatch
-
-julia> onehotbatch([:b, :a, :b], [:a, :b, :c])
+julia> Flux.onehotbatch([:b, :a, :b], [:a, :b, :c])
 3×3 Flux.OneHotMatrix{Array{Flux.OneHotVector,1}}:
  0  1  0
  1  0  1
@@ -107,13 +102,12 @@ Base.argmax(xs::OneHotVector) = xs.ix
 Inverse operations of [`onehot`](@ref).
 
+# Examples
 ```jldoctest
-julia> using Flux: onecold
-
-julia> onecold([true, false, false], [:a, :b, :c])
+julia> Flux.onecold([true, false, false], [:a, :b, :c])
 :a
 
-julia> onecold([0.3, 0.2, 0.5], [:a, :b, :c])
+julia> Flux.onecold([0.3, 0.2, 0.5], [:a, :b, :c])
 :c
 ```
 """
@@ -6,19 +6,20 @@ const ϵ = 1e-8
 # TODO: should use weak refs
 
 """
-    Descent(η)
+    Descent(η = 0.1)
 
 Classic gradient descent optimiser with learning rate `η`.
 For each parameter `p` and its gradient `δp`, this runs `p -= η*δp`
 
-## Parameters
-- Learning Rate (η): The amount by which the gradients are discounted before updating the weights. Defaults to `0.1`.
+# Parameters
+- Learning rate (`η`): Amount by which the gradients are discounted before updating
+  the weights.
 
-## Example
-```julia-repl
-opt = Descent() # uses default η (0.1)
+# Examples
+```julia
+opt = Descent()
 
-opt = Descent(0.3) # use provided η
+opt = Descent(0.3)
 
 ps = params(model)
@@ -40,17 +41,19 @@ function apply!(o::Descent, x, Δ)
 end
 
 """
-    Momentum(η, ρ)
+    Momentum(η = 0.01, ρ = 0.9)
 
-Gradient descent with learning rate `η` and momentum `ρ`.
+Gradient descent optimizer with learning rate `η` and momentum `ρ`.
 
-## Parameters
-- Learning Rate (`η`): Amount by which gradients are discounted before updating the weights. Defaults to `0.01`.
-- Momentum (`ρ`): Parameter that accelerates descent in the relevant direction and dampens oscillations. Defaults to `0.9`.
+# Parameters
+- Learning rate (`η`): Amount by which gradients are discounted before updating the
+  weights.
+- Momentum (`ρ`): Controls the acceleration of gradient descent in the relevant direction
+  and therefore the dampening of oscillations.
 
-## Examples
+# Examples
 ```julia
-opt = Momentum() # uses defaults of η = 0.01 and ρ = 0.9
+opt = Momentum()
 
 opt = Momentum(0.01, 0.99)
 ```
@@ -71,17 +74,18 @@ function apply!(o::Momentum, x, Δ)
 end
 
 """
-    Nesterov(η, ρ)
+    Nesterov(η = 0.001, ρ = 0.9)
 
-Gradient descent with learning rate `η` and Nesterov momentum `ρ`.
+Gradient descent optimizer with learning rate `η` and Nesterov momentum `ρ`.
 
-## Parameters
-- Learning Rate (η): Amount by which the gradients are dicsounted berfore updating the weights. Defaults to `0.001`.
-- Nesterov Momentum (ρ): Parameters controlling the amount of nesterov momentum to be applied. Defaults to `0.9`.
+# Parameters
+- Learning rate (`η`): Amount by which the gradients are discounted before updating the
+  weights.
+- Nesterov momentum (`ρ`): The amount of Nesterov momentum to be applied.
 
-## Examples
+# Examples
 ```julia
-opt = Nesterov() # uses defaults η = 0.001 and ρ = 0.9
+opt = Nesterov()
 
 opt = Nesterov(0.003, 0.95)
 ```
@@ -103,23 +107,23 @@ function apply!(o::Nesterov, x, Δ)
 end
 
 """
-    RMSProp(η, ρ)
+    RMSProp(η = 0.001, ρ = 0.9)
 
-Implements the RMSProp algortihm. Often a good choice for recurrent networks. Parameters other than learning rate generally don't need tuning.
+Optimizer using the
+[RMSProp](https://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf)
+algorithm. Often a good choice for recurrent networks. Parameters other than learning rate
+generally don't need tuning.
 
-## Parameters
-- Learning Rate (η): Defaults to `0.001`.
-- Rho (ρ): Defaults to `0.9`.
+# Parameters
+- Learning rate (`η`)
+- Momentum (`ρ`)
 
-## Examples
+# Examples
 ```julia
-opt = RMSProp() # uses default η = 0.001 and ρ = 0.9
+opt = RMSProp()
 
 opt = RMSProp(0.002, 0.95)
 ```
-
-## References
-[RMSProp](https://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf)
 """
 mutable struct RMSProp
   eta::Float64
@@ -137,23 +141,21 @@ function apply!(o::RMSProp, x, Δ)
 end
 
 """
-    ADAM(η, β::Tuple)
+    ADAM(η = 0.001, β::Tuple = (0.9, 0.999))
 
-Implements the ADAM optimiser.
+[ADAM](https://arxiv.org/abs/1412.6980v8) optimiser.
 
-## Paramters
-- Learning Rate (`η`): Defaults to `0.001`.
-- Beta (`β::Tuple`): The first element refers to β1 and the second to β2. Defaults to `(0.9, 0.999)`.
-
-## Examples
+# Parameters
+- Learning rate (`η`)
+- Decay of momentums (`β::Tuple`): Exponential decay for the first (β1) and the
+  second (β2) momentum estimate.
+
+# Examples
 ```julia
-opt = ADAM() # uses the default η = 0.001 and β = (0.9, 0.999)
+opt = ADAM()
 
 opt = ADAM(0.001, (0.9, 0.8))
 ```
-## References
-[ADAM](https://arxiv.org/abs/1412.6980v8) optimiser.
 """
 mutable struct ADAM
   eta::Float64
@@ -174,24 +176,21 @@ function apply!(o::ADAM, x, Δ)
 end
 
 """
-    RADAM(η, β::Tuple)
+    RADAM(η = 0.001, β::Tuple = (0.9, 0.999))
 
-Implements the rectified ADAM optimizer.
+[Rectified ADAM](https://arxiv.org/pdf/1908.03265v1.pdf) optimizer.
 
-## Parameters
-- Learning Rate (η): Defaults to `0.001`
-- Beta (β::Tuple): The first element refers to β1 and the second to β2. Defaults to `(0.9, 0.999)`.
-
-## Examples
+# Parameters
+- Learning rate (`η`)
+- Decay of momentums (`β::Tuple`): Exponential decay for the first (β1) and the
+  second (β2) momentum estimate.
+
+# Examples
 ```julia
-opt = RADAM() # uses the default η = 0.001 and β = (0.9, 0.999)
+opt = RADAM()
 
 opt = RADAM(0.001, (0.9, 0.8))
 ```
-
-## References
-[RADAM](https://arxiv.org/pdf/1908.03265v1.pdf) optimiser (Rectified ADAM).
 """
 mutable struct RADAM
   eta::Float64
@@ -219,22 +218,21 @@ function apply!(o::RADAM, x, Δ)
 end
 
 """
-    AdaMax(η, β::Tuple)
+    AdaMax(η = 0.001, β::Tuple = (0.9, 0.999))
 
-Variant of ADAM based on ∞-norm.
+[AdaMax](https://arxiv.org/abs/1412.6980v9) is a variant of ADAM based on the ∞-norm.
 
-## Parameters
-- Learning Rate (η): Defaults to `0.001`
-- Beta (β::Tuple): The first element refers to β1 and the second to β2. Defaults to `(0.9, 0.999)`.
+# Parameters
+- Learning rate (`η`)
+- Decay of momentums (`β::Tuple`): Exponential decay for the first (β1) and the
+  second (β2) momentum estimate.
 
-## Examples
+# Examples
 ```julia
-opt = AdaMax() # uses default η and β
+opt = AdaMax()
 
 opt = AdaMax(0.001, (0.9, 0.995))
 ```
-## References
-[AdaMax](https://arxiv.org/abs/1412.6980v9) optimiser.
 """
 mutable struct AdaMax
   eta::Float64
@@ -255,23 +253,21 @@ function apply!(o::AdaMax, x, Δ)
 end
 
 """
-    ADAGrad(η)
+    ADAGrad(η = 0.1)
 
-Implements AdaGrad. It has parameter specific learning rates based on how frequently it is updated.
+[ADAGrad](http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf) optimizer. It has
+parameter specific learning rates based on how frequently it is updated.
+Parameters don't need tuning.
 
-## Parameters
-- Learning Rate (η): Defaults to `0.1`
+# Parameters
+- Learning rate (`η`)
 
-## Examples
+# Examples
 ```julia
-opt = ADAGrad() # uses default η = 0.1
+opt = ADAGrad()
 
 opt = ADAGrad(0.001)
 ```
-
-## References
-[ADAGrad](http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf) optimiser.
-Parameters don't need tuning.
 """
 mutable struct ADAGrad
   eta::Float64
@@ -288,21 +284,21 @@ function apply!(o::ADAGrad, x, Δ)
 end
 
 """
-    ADADelta(ρ)
+    ADADelta(ρ = 0.9)
 
-Version of ADAGrad that adapts learning rate based on a window of past gradient updates. Parameters don't need tuning.
+[ADADelta](https://arxiv.org/abs/1212.5701) is a version of ADAGrad adapting its learning
+rate based on a window of past gradient updates.
+Parameters don't need tuning.
 
-## Parameters
-- Rho (ρ): Factor by which gradient is decayed at each time step. Defaults to `0.9`.
+# Parameters
+- Rho (`ρ`): Factor by which gradient is decayed at each time step.
 
-## Examples
+# Examples
 ```julia
-opt = ADADelta() # uses default ρ = 0.9
+opt = ADADelta()
 
 opt = ADADelta(0.89)
 ```
-
-## References
-[ADADelta](https://arxiv.org/abs/1212.5701) optimiser.
 """
 mutable struct ADADelta
   rho::Float64
@@ -321,22 +317,22 @@ function apply!(o::ADADelta, x, Δ)
 end
 
 """
-    AMSGrad(η, β::Tuple)
+    AMSGrad(η = 0.001, β::Tuple = (0.9, 0.999))
 
-Implements AMSGrad version of the ADAM optimiser. Parameters don't need tuning.
+The [AMSGrad](https://openreview.net/forum?id=ryQu7f-RZ) version of the ADAM
+optimiser. Parameters don't need tuning.
 
-## Parameters
-- Learning Rate (η): Defaults to `0.001`.
-- Beta (β::Tuple): The first element refers to β1 and the second to β2. Defaults to `(0.9, 0.999)`.
+# Parameters
+- Learning Rate (`η`)
+- Decay of momentums (`β::Tuple`): Exponential decay for the first (β1) and the
+  second (β2) momentum estimate.
 
-## Examples
+# Examples
 ```julia
-opt = AMSGrad() # uses default η and β
+opt = AMSGrad()
 
 opt = AMSGrad(0.001, (0.89, 0.995))
 ```
-
-## References
-[AMSGrad](https://openreview.net/forum?id=ryQu7f-RZ) optimiser.
 """
 mutable struct AMSGrad
   eta::Float64
@@ -356,22 +352,22 @@ function apply!(o::AMSGrad, x, Δ)
 end
 
 """
-    NADAM(η, β::Tuple)
+    NADAM(η = 0.001, β::Tuple = (0.9, 0.999))
 
-Nesterov variant of ADAM. Parameters don't need tuning.
+[NADAM](http://cs229.stanford.edu/proj2015/054_report.pdf) is a Nesterov variant of ADAM.
+Parameters don't need tuning.
 
-## Parameters
-- Learning Rate (η): Defaults to `0.001`.
-- Beta (β::Tuple): The first element refers to β1 and the second to β2. Defaults to `(0.9, 0.999)`.
+# Parameters
+- Learning rate (`η`)
+- Decay of momentums (`β::Tuple`): Exponential decay for the first (β1) and the
+  second (β2) momentum estimate.
 
-## Examples
+# Examples
 ```julia
-opt = NADAM() # uses default η and β
+opt = NADAM()
 
 opt = NADAM(0.002, (0.89, 0.995))
 ```
-
-## References
-[NADAM](http://cs229.stanford.edu/proj2015/054_report.pdf) optimiser.
 """
 mutable struct NADAM
   eta::Float64
@@ -392,23 +388,23 @@ function apply!(o::NADAM, x, Δ)
 end
 
 """
-    ADAMW(η, β::Tuple, decay)
+    ADAMW(η = 0.001, β::Tuple = (0.9, 0.999), decay = 0)
 
-Variant of ADAM defined by fixing weight decay regularization.
+[ADAMW](https://arxiv.org/abs/1711.05101) is a variant of ADAM fixing (as in repairing) its
+weight decay regularization.
 
-## Parameters
-- Learning Rate (η): Defaults to `0.001`.
-- Beta (β::Tuple): The first element refers to β1 and the second to β2. Defaults to (0.9, 0.999).
-- decay: Decay applied to weights during optimisation. Defaults to 0.
+# Parameters
+- Learning rate (`η`)
+- Decay of momentums (`β::Tuple`): Exponential decay for the first (β1) and the
+  second (β2) momentum estimate.
+- `decay`: Decay applied to weights during optimisation.
 
-## Examples
+# Examples
 ```julia
-opt = ADAMW() # uses default η, β and decay
+opt = ADAMW()
 
 opt = ADAMW(0.001, (0.89, 0.995), 0.1)
 ```
-
-## References
-[ADAMW](https://arxiv.org/abs/1711.05101)
 """
 ADAMW(η = 0.001, β = (0.9, 0.999), decay = 0) =
   Optimiser(ADAM(η, β), WeightDecay(decay))
@@ -441,14 +437,13 @@ function apply!(o::Optimiser, x, Δ)
 end
 
 """
-    InvDecay(γ)
+    InvDecay(γ = 0.001)
 
-Applies inverse time decay to an optimiser, i.e., the effective step size at iteration `n` is `eta / (1 + γ * n)` where `eta` is the initial step size. The wrapped optimiser's step size is not modified.
+Apply inverse time decay to an optimiser, so that the effective step size at
+iteration `n` is `eta / (1 + γ * n)` where `eta` is the initial step size.
+The wrapped optimiser's step size is not modified.
 
-## Parameters
-- gamma (γ): Defaults to `0.001`
-
-## Example
+# Examples
 ```julia
 Optimiser(InvDecay(..), Opt(..))
 ```
@@ -469,20 +464,23 @@ function apply!(o::InvDecay, x, Δ)
 end
 
 """
-    ExpDecay(eta, decay, decay_step, clip)
+    ExpDecay(eta = 0.001, decay = 0.1, decay_step = 1000, clip = 1e-4)
 
-Discount the learning rate `eta` by a multiplicative factor `decay` every `decay_step` till a minimum of `clip`.
+Discount the learning rate `eta` by the factor `decay` every `decay_step` steps till
+a minimum of `clip`.
 
-## Parameters
-- Learning Rate (eta): Defaults to `0.001`.
-- decay: Factor by which the learning rate is discounted. Defaults to `0.1`.
-- decay_step: Schedules decay operations by setting number of steps between two decay operations. Defaults to `1000`.
-- clip: Minimum value of learning rate. Defaults to `1e-4`.
+# Parameters
+- Learning rate (`eta`)
+- `decay`: Factor by which the learning rate is discounted.
+- `decay_step`: Schedule decay operations by setting number of steps between two decay
+  operations.
+- `clip`: Minimum value of learning rate.
 
-## Example
+# Examples
+To apply exponential decay to an optimiser:
 ```julia
 Optimiser(ExpDecay(..), Opt(..))
+
+opt = Optimiser(ExpDecay(), ADAM())
 ```
 """
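To make the decay wrappers concrete, a hedged end-to-end sketch of composing `ExpDecay` with `ADAM` in a manual update step; the tiny `model`, `loss`, and data here are placeholders invented for illustration, and the standard Flux training API is assumed:

```julia
using Flux

model = Dense(10, 1)
loss(x, y) = Flux.mse(model(x), y)
x, y = rand(Float32, 10, 16), rand(Float32, 1, 16)

opt = Optimiser(ExpDecay(0.001, 0.1, 1000, 1e-4), ADAM())  # decayed learning rate on top of ADAM
ps  = Flux.params(model)
gs  = gradient(() -> loss(x, y), ps)
Flux.Optimise.update!(opt, ps, gs)
```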
@@ -507,12 +505,12 @@ function apply!(o::ExpDecay, x, Δ)
 end
 
 """
-    WeightDecay(wd)
+    WeightDecay(wd = 0)
 
-Decays the weight by `wd`
+Decay weights by `wd`.
 
-## Parameters
-- weight decay (wd): 0
+# Parameters
+- Weight decay (`wd`)
 """
 mutable struct WeightDecay
   wd::Real
@@ -43,9 +43,8 @@ struct StopException <: Exception end
 Call `Flux.stop()` in a callback to indicate when a callback condition is met.
 This would trigger the train loop to stop and exit.
 
+# Examples
 ```julia
-# Example callback:
-
 cb = function ()
   accuracy() > 0.9 && Flux.stop()
 end
@@ -65,12 +64,12 @@ In case datapoints `d` are of numeric array type, assumes no splatting is needed
 and computes the gradient of `loss(d)`.
 
 Takes a callback as keyword argument `cb`. For example, this will print "training"
-every 10 seconds:
+every 10 seconds (using [`throttle`](@ref)):
 
     train!(loss, params, data, opt,
           cb = throttle(() -> println("training"), 10))
 
-The callback can call `Flux.stop()` to interrupt the training loop.
+The callback can call [`Flux.stop()`](@ref) to interrupt the training loop.
 
 Multiple optimisers and callbacks can be passed to `opt` and `cb` as arrays.
 """
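A compact sketch of the callback mechanism described above, wiring `throttle` and `Flux.stop()` into `train!` (illustrative only; the toy model, loss, and data are invented for the example, standard Flux API assumed):

```julia
using Flux

model = Dense(2, 1)
loss(x, y) = Flux.mse(model(x), y)
data = [(rand(Float32, 2, 8), rand(Float32, 1, 8)) for _ in 1:100]
opt  = Descent(0.1)

# Print the loss at most every 10 seconds and stop once it is small enough.
evalcb() = begin
    l = loss(data[1]...)
    println("loss = ", l)
    l < 1e-3 && Flux.stop()
end

Flux.train!(loss, Flux.params(model), data, opt, cb = Flux.throttle(evalcb, 10))
```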
@@ -106,11 +105,12 @@ end
 Run `body` `N` times. Mainly useful for quickly doing multiple epochs of
 training in a REPL.
 
-```julia
-julia> @epochs 2 println("hello")
-INFO: Epoch 1
+# Examples
+```jldoctest
+julia> Flux.@epochs 2 println("hello")
+[ Info: Epoch 1
 hello
-INFO: Epoch 2
+[ Info: Epoch 2
 hello
 ```
 """
src/utils.jl
@@ -125,8 +125,9 @@ unstack(xs, dim) = [copy(selectdim(xs, dim, i)) for i in 1:size(xs, dim)]
 Split `xs` into `n` parts.
 
-```julia
-julia> chunk(1:10, 3)
+# Examples
+```jldoctest
+julia> Flux.chunk(1:10, 3)
 3-element Array{Array{Int64,1},1}:
  [1, 2, 3, 4]
  [5, 6, 7, 8]
@@ -142,11 +143,12 @@ batchindex(xs, i) = (reverse(Base.tail(reverse(axes(xs))))..., i)
 Count the number of times that each element of `xs` appears.
 
-```julia
-julia> frequencies(['a','b','b'])
+# Examples
+```jldoctest
+julia> Flux.frequencies(['a','b','b'])
 Dict{Char,Int64} with 2 entries:
-  'b' => 2
   'a' => 1
+  'b' => 2
 ```
 """
 function frequencies(xs)
@@ -166,8 +168,9 @@ squeezebatch(x) = reshape(x, head(size(x)))
 Batch the arrays in `xs` into a single array.
 
-```julia
-julia> batch([[1,2,3],[4,5,6]])
+# Examples
+```jldoctest
+julia> Flux.batch([[1,2,3],[4,5,6]])
 3×2 Array{Int64,2}:
  1  4
 2  5
@@ -211,8 +214,9 @@ Base.rpad(v::AbstractVector, n::Integer, p) = [v; fill(p, max(n - length(v), 0))
 Take a list of `N` sequences, and turn them into a single sequence where each
 item is a batch of `N`. Short sequences will be padded by `pad`.
 
-```julia
-julia> batchseq([[1, 2, 3], [4, 5]], 0)
+# Examples
+```jldoctest
+julia> Flux.batchseq([[1, 2, 3], [4, 5]], 0)
 3-element Array{Array{Int64,1},1}:
  [1, 4]
  [2, 5]
@@ -269,11 +273,15 @@ end
 # Other
 
 """
-Returns a function that when invoked, will only be triggered at most once
-during `timeout` seconds. Normally, the throttled function will run
-as much as it can, without ever going more than once per `wait` duration;
-but if you'd like to disable the execution on the leading edge, pass
-`leading=false`. To enable execution on the trailing edge, ditto.
+    throttle(f, timeout; leading=true, trailing=false)
+
+Return a function that when invoked, will only be triggered at most once
+during `timeout` seconds.
+
+Normally, the throttled function will run as much as it can, without ever
+going more than once per `wait` duration; but if you'd like to disable the
+execution on the leading edge, pass `leading=false`. To enable execution on
+the trailing edge, pass `trailing=true`.
 """
 function throttle(f, timeout; leading=true, trailing=false)
   cooldown = true