Improve docstrings

Improvements like...
   - fixing typos,
   - removing trailing and double whitespaces,
   - using `jldoctest` blocks where applicable,
   - fixing, updating or correctly setting up existing doctests,
   - improving consistency (for example, always use "# Examples" instead
     of other variants),
   - removing empty lines between docstrings and functions,
   - instead of mentioning keywords, put them into the docstring,
   - adding some missing but useful keywords,
   - adding references (`@ref`),
   - using LaTeX math where applicable, and
   - linking papers.

Debatable stuff that is untouched:
   - BE/AE s/z irregularities ("normalise" versus "normalize") since
     most papers use the AE version while the Flux source code was
     written with BE spelling.
   - Names of normalization functions are capitalized
     ("Batch Normalization" instead of "batch normalization").
This commit is contained in:
janEbert 2019-08-31 11:39:28 +02:00
parent c76b7315ac
commit ab86e350f2
12 changed files with 337 additions and 315 deletions

View File

@ -33,9 +33,10 @@ const TESTLABELS = joinpath(dir, "t10k-labels-idx1-ubyte")
Load the Fashion-MNIST images. Load the Fashion-MNIST images.
Each image is a 28×28 array of `Gray` colour values (see Colors.jl). Each image is a 28×28 array of `Gray` colour values
(see [Colors.jl](https://github.com/JuliaGraphics/Colors.jl)).
Returns the 60,000 training images by default; pass `:test` to retreive the Return the 60,000 training images by default; pass `:test` to retrieve the
10,000 test images. 10,000 test images.
""" """
function images(set = :train) function images(set = :train)
@ -49,10 +50,10 @@ end
labels() labels()
labels(:test) labels(:test)
Load the labels corresponding to each of the images returned from `images()`. Load the labels corresponding to each of the images returned from [`images()`](@ref).
Each label is a number from 0-9. Each label is a number from 0-9.
Returns the 60,000 training labels by default; pass `:test` to retreive the Return the 60,000 training labels by default; pass `:test` to retrieve the
10,000 test labels. 10,000 test labels.
""" """
function labels(set = :train) function labels(set = :train)

View File

@ -2,13 +2,12 @@
Fisher's classic iris dataset. Fisher's classic iris dataset.
Measurements from 3 different species of iris: setosa, versicolor and Measurements from 3 different species of iris: setosa, versicolor and
virginica. There are 50 examples of each species. virginica. There are 50 examples of each species.
There are 4 measurements for each example: sepal length, sepal width, petal There are 4 measurements for each example: sepal length, sepal width,
length and petal width. The measurements are in centimeters. petal length and petal width. The measurements are in centimeters.
The module retrieves the data from the [UCI Machine Learning Repository](https://archive.ics.uci.edu/ml/datasets/iris). The module retrieves the data from the [UCI Machine Learning Repository](https://archive.ics.uci.edu/ml/datasets/iris).
""" """
module Iris module Iris
@ -33,9 +32,7 @@ end
Get the labels of the iris dataset, a 150 element array of strings listing the Get the labels of the iris dataset, a 150 element array of strings listing the
species of each example. species of each example.
```jldoctest ```jldoctest; setup = :(Flux.Data.Iris.load())
julia> using Flux
julia> labels = Flux.Data.Iris.labels(); julia> labels = Flux.Data.Iris.labels();
julia> summary(labels) julia> summary(labels)
@ -54,13 +51,11 @@ end
""" """
features() features()
Get the features of the iris dataset. This is a 4x150 matrix of Float64 Get the features of the iris dataset. This is a 4x150 matrix of Float64
elements. It has a row for each feature (sepal length, sepal width, elements. It has a row for each feature (sepal length, sepal width,
petal length, petal width) and a column for each example. petal length, petal width) and a column for each example.
```jldoctest ```jldoctest; setup = :(Flux.Data.Iris.load())
julia> using Flux
julia> features = Flux.Data.Iris.features(); julia> features = Flux.Data.Iris.features();
julia> summary(features) julia> summary(features)

View File

@ -83,9 +83,10 @@ getfeatures(io::IO, index::Integer) = vec(getimage(io, index))
Load the MNIST images. Load the MNIST images.
Each image is a 28×28 array of `Gray` colour values (see Colors.jl). Each image is a 28×28 array of `Gray` colour values
(see [Colors.jl](https://github.com/JuliaGraphics/Colors.jl)).
Returns the 60,000 training images by default; pass `:test` to retreive the Return the 60,000 training images by default; pass `:test` to retrieve the
10,000 test images. 10,000 test images.
""" """
function images(set = :train) function images(set = :train)
@ -99,10 +100,10 @@ end
labels() labels()
labels(:test) labels(:test)
Load the labels corresponding to each of the images returned from `images()`. Load the labels corresponding to each of the images returned from [`images()`](@ref).
Each label is a number from 0-9. Each label is a number from 0-9.
Returns the 60,000 training labels by default; pass `:test` to retreive the Return the 60,000 training labels by default; pass `:test` to retrieve the
10,000 test labels. 10,000 test labels.
""" """
function labels(set = :train) function labels(set = :train)

View File

@ -4,17 +4,23 @@
Chain multiple layers / functions together, so that they are called in sequence Chain multiple layers / functions together, so that they are called in sequence
on a given input. on a given input.
```julia
m = Chain(x -> x^2, x -> x+1)
m(5) == 26
m = Chain(Dense(10, 5), Dense(5, 2))
x = rand(10)
m(x) == m[2](m[1](x))
```
`Chain` also supports indexing and slicing, e.g. `m[2]` or `m[1:end-1]`. `Chain` also supports indexing and slicing, e.g. `m[2]` or `m[1:end-1]`.
`m[1:3](x)` will calculate the output of the first three layers. `m[1:3](x)` will calculate the output of the first three layers.
# Examples
```jldoctest
julia> m = Chain(x -> x^2, x -> x+1);
julia> m(5) == 26
true
julia> m = Chain(Dense(10, 5), Dense(5, 2));
julia> x = rand(10);
julia> m(x) == m[2](m[1](x))
true
```
""" """
struct Chain{T<:Tuple} struct Chain{T<:Tuple}
layers::T layers::T
@ -60,6 +66,7 @@ outdims(c::Chain, isize) = foldl(∘, map(l -> (x -> outdims(l, x)), c.layers))(
# only slightly changed to better handle interaction with Zygote @dsweber2 # only slightly changed to better handle interaction with Zygote @dsweber2
""" """
activations(c::Chain, input) activations(c::Chain, input)
Calculate the forward results of each layers in Chain `c` with `input` as model input. Calculate the forward results of each layers in Chain `c` with `input` as model input.
""" """
function activations(c::Chain, input) function activations(c::Chain, input)
@ -78,14 +85,15 @@ extraChain(::Tuple{}, x) = ()
""" """
Dense(in::Integer, out::Integer, σ = identity) Dense(in::Integer, out::Integer, σ = identity)
Creates a traditional `Dense` layer with parameters `W` and `b`. Create a traditional `Dense` layer with parameters `W` and `b`.
y = σ.(W * x .+ b) y = σ.(W * x .+ b)
The input `x` must be a vector of length `in`, or a batch of vectors represented The input `x` must be a vector of length `in`, or a batch of vectors represented
as an `in × N` matrix. The out `y` will be a vector or batch of length `out`. as an `in × N` matrix. The out `y` will be a vector or batch of length `out`.
```julia # Examples
```jldoctest; setup = :(using Random; Random.seed!(0))
julia> d = Dense(5, 2) julia> d = Dense(5, 2)
Dense(5, 2) Dense(5, 2)
@ -145,7 +153,7 @@ outdims(l::Dense, isize) = (size(l.W)[1],)
""" """
Diagonal(in::Integer) Diagonal(in::Integer)
Creates an element-wise linear transformation layer with learnable Create an element-wise linear transformation layer with learnable
vectors `α` and `β`: vectors `α` and `β`:
y = α .* x .+ β y = α .* x .+ β
@ -176,8 +184,8 @@ outdims(l::Diagonal, isize) = (length(l.α),)
""" """
Maxout(over) Maxout(over)
`Maxout` is a neural network layer, which has a number of internal layers, `Maxout` is a neural network layer which has a number of internal layers
which all have the same input, and the maxout returns the elementwise maximium which all receive the same input. The layer returns the elementwise maximium
of the internal layers' outputs. of the internal layers' outputs.
Maxout over linear dense layers satisfies the univeral approximation theorem. Maxout over linear dense layers satisfies the univeral approximation theorem.
@ -196,17 +204,18 @@ end
""" """
Maxout(f, n_alts) Maxout(f, n_alts)
Constructs a Maxout layer over `n_alts` instances of the layer given by `f`. Construct a Maxout layer over `n_alts` instances of the layer given by `f`.
The function takes no arguement and should return some callable layer. The function takes no arguments and should return some callable layer.
Conventionally this is a linear dense layer. Conventionally, this is a linear dense layer.
For example the following example which # Examples
will construct a `Maxout` layer over 4 internal dense linear layers,
each identical in structure (784 inputs, 128 outputs). This constructs a `Maxout` layer over 4 internal dense linear layers, each
identical in structure (784 inputs, 128 outputs):
```julia ```julia
insize = 784 insize = 784
outsize = 128 outsize = 128
Maxout(()->Dense(insize, outsize), 4) Maxout(()->Dense(insize, outsize), 4)
``` ```
""" """
function Maxout(f, n_alts) function Maxout(f, n_alts)
@ -223,16 +232,18 @@ end
outdims(l::Maxout, isize) = outdims(first(l.over), isize) outdims(l::Maxout, isize) = outdims(first(l.over), isize)
""" """
SkipConnection(layers, connection) SkipConnection(layer, connection)
Creates a Skip Connection, of a layer or `Chain` of consecutive layers Create a skip connection which consists of a layer or `Chain` of consecutive
plus a shortcut connection. The connection function will combine the result of the layers layers and a shortcut connection linking the block's input to the output
with the original input, to give the final output. through a user-supplied 2-argument callable. The first argument to the callable
will be propagated through the given `layer` while the second is the unchanged,
"skipped" input.
The simplest 'ResNet'-type connection is just `SkipConnection(layer, +)`, The simplest "ResNet"-type connection is just `SkipConnection(layer, +)`,
and requires the output of the layers to be the same shape as the input. and requires the output of the layers to be the same shape as the input.
Here is a more complicated example: Here is a more complicated example:
``` ```julia
m = Conv((3,3), 4=>7, pad=(1,1)) m = Conv((3,3), 4=>7, pad=(1,1))
x = ones(5,5,4,10); x = ones(5,5,4,10);
size(m(x)) == (5, 5, 7, 10) size(m(x)) == (5, 5, 7, 10)

View File

@ -8,25 +8,26 @@ _convtransoutdims(isize, ksize, ssize, dsize, pad) = (isize .- 1).*ssize .+ 1 .+
expand(N, i::Tuple) = i expand(N, i::Tuple) = i
expand(N, i::Integer) = ntuple(_ -> i, N) expand(N, i::Integer) = ntuple(_ -> i, N)
""" """
Conv(size, in=>out) Conv(size, in => out, σ = identity; init = glorot_uniform,
Conv(size, in=>out, relu) stride = 1, pad = 0, dilation = 1)
Standard convolutional layer. `size` should be a tuple like `(2, 2)`. Standard convolutional layer. `size` should be a tuple like `(2, 2)`.
`in` and `out` specify the number of input and output channels respectively. `in` and `out` specify the number of input and output channels respectively.
Example: Applying Conv layer to a 1-channel input using a 2x2 window size,
giving us a 16-channel output. Output is activated with ReLU.
size = (2,2)
in = 1
out = 16
Conv((2, 2), 1=>16, relu)
Data should be stored in WHCN order (width, height, # channels, batch size). Data should be stored in WHCN order (width, height, # channels, batch size).
In other words, a 100×100 RGB image would be a `100×100×3×1` array, In other words, a 100×100 RGB image would be a `100×100×3×1` array,
and a batch of 50 would be a `100×100×3×50` array. and a batch of 50 would be a `100×100×3×50` array.
Takes the keyword arguments `pad`, `stride` and `dilation`. # Examples
Apply a `Conv` layer to a 1-channel input using a 2×2 window size, giving us a
16-channel output. Output is activated with ReLU.
```julia
size = (2,2)
in = 1
out = 16
Conv(size, in => out, relu)
```
""" """
struct Conv{N,M,F,A,V} struct Conv{N,M,F,A,V}
σ::F σ::F
@ -76,8 +77,8 @@ end
""" """
outdims(l::Conv, isize::Tuple) outdims(l::Conv, isize::Tuple)
Calculate the output dimensions given the input dimensions, `isize`. Calculate the output dimensions given the input dimensions `isize`.
Batch size and channel size are ignored as per `NNlib.jl`. Batch size and channel size are ignored as per [NNlib.jl](https://github.com/FluxML/NNlib.jl).
```julia ```julia
m = Conv((3, 3), 3 => 16) m = Conv((3, 3), 3 => 16)
@ -89,17 +90,15 @@ outdims(l::Conv, isize) =
output_size(DenseConvDims(_paddims(isize, size(l.weight)), size(l.weight); stride = l.stride, padding = l.pad, dilation = l.dilation)) output_size(DenseConvDims(_paddims(isize, size(l.weight)), size(l.weight); stride = l.stride, padding = l.pad, dilation = l.dilation))
""" """
ConvTranspose(size, in=>out) ConvTranspose(size, in => out, σ = identity; init = glorot_uniform,
ConvTranspose(size, in=>out, relu) stride = 1, pad = 0, dilation = 1)
Standard convolutional transpose layer. `size` should be a tuple like `(2, 2)`. Standard convolutional transpose layer. `size` should be a tuple like `(2, 2)`.
`in` and `out` specify the number of input and output channels respectively. `in` and `out` specify the number of input and output channels respectively.
Data should be stored in WHCN order (width, height, # channels, # batches). Data should be stored in WHCN order (width, height, # channels, batch size).
In other words, a 100×100 RGB image would be a `100×100×3×1` array, In other words, a 100×100 RGB image would be a `100×100×3×1` array,
and a batch of 50 would be a `100×100×3×50` array. and a batch of 50 would be a `100×100×3×50` array.
Takes the keyword arguments `pad`, `stride` and `dilation`.
""" """
struct ConvTranspose{N,M,F,A,V} struct ConvTranspose{N,M,F,A,V}
σ::F σ::F
@ -165,18 +164,16 @@ end
outdims(l::ConvTranspose{N}, isize) where N = _convtransoutdims(isize[1:2], size(l.weight)[1:N], l.stride, l.dilation, l.pad) outdims(l::ConvTranspose{N}, isize) where N = _convtransoutdims(isize[1:2], size(l.weight)[1:N], l.stride, l.dilation, l.pad)
""" """
DepthwiseConv(size, in=>out) DepthwiseConv(size, in => out, σ = identity; init = glorot_uniform,
DepthwiseConv(size, in=>out, relu) stride = 1, pad = 0, dilation = 1)
Depthwise convolutional layer. `size` should be a tuple like `(2, 2)`. Depthwise convolutional layer. `size` should be a tuple like `(2, 2)`.
`in` and `out` specify the number of input and output channels respectively. `in` and `out` specify the number of input and output channels respectively.
Note that `out` must be an integer multiple of `in`. Note that `out` must be an integer multiple of `in`.
Data should be stored in WHCN order (width, height, # channels, # batches). Data should be stored in WHCN order (width, height, # channels, batch size).
In other words, a 100×100 RGB image would be a `100×100×3×1` array, In other words, a 100×100 RGB image would be a `100×100×3×1` array,
and a batch of 50 would be a `100×100×3×50` array. and a batch of 50 would be a `100×100×3×50` array.
Takes the keyword arguments `pad`, `stride` and `dilation`.
""" """
struct DepthwiseConv{N,M,F,A,V} struct DepthwiseConv{N,M,F,A,V}
σ::F σ::F
@ -233,25 +230,26 @@ outdims(l::DepthwiseConv, isize) =
output_size(DepthwiseConvDims(_paddims(isize, (1, 1, size(l.weight)[end], 1)), size(l.weight); stride = l.stride, padding = l.pad, dilation = l.dilation)) output_size(DepthwiseConvDims(_paddims(isize, (1, 1, size(l.weight)[end], 1)), size(l.weight); stride = l.stride, padding = l.pad, dilation = l.dilation))
""" """
CrossCor(size, in=>out) CrossCor(size, in => out, σ = identity; init = glorot_uniform,
CrossCor(size, in=>out, relu) stride = 1, pad = 0, dilation = 1)
Standard cross convolutional layer. `size` should be a tuple like `(2, 2)`. Standard cross convolutional layer. `size` should be a tuple like `(2, 2)`.
`in` and `out` specify the number of input and output channels respectively. `in` and `out` specify the number of input and output channels respectively.
Example: Applying CrossCor layer to a 1-channel input using a 2x2 window size, Data should be stored in WHCN order (width, height, # channels, batch size).
giving us a 16-channel output. Output is activated with ReLU.
size = (2,2)
in = 1
out = 16
CrossCor((2, 2), 1=>16, relu)
Data should be stored in WHCN order (width, height, # channels, # batches).
In other words, a 100×100 RGB image would be a `100×100×3×1` array, In other words, a 100×100 RGB image would be a `100×100×3×1` array,
and a batch of 50 would be a `100×100×3×50` array. and a batch of 50 would be a `100×100×3×50` array.
Takes the keyword arguments `pad`, `stride` and `dilation`. # Examples
Apply a `CrossCor` layer to a 1-channel input using a 2×2 window size, giving us a
16-channel output. Output is activated with ReLU.
```julia
size = (2,2)
in = 1
out = 16
CrossCor((2, 2), 1=>16, relu)
```
""" """
struct CrossCor{N,M,F,A,V} struct CrossCor{N,M,F,A,V}
σ::F σ::F
@ -357,11 +355,9 @@ function Base.show(io::IO, g::GlobalMeanPool)
end end
""" """
MaxPool(k) MaxPool(k; pad = 0, stride = k)
Max pooling layer. `k` stands for the size of the window for each dimension of the input. Max pooling layer. `k` is the size of the window for each dimension of the input.
Takes the keyword arguments `pad` and `stride`.
""" """
struct MaxPool{N,M} struct MaxPool{N,M}
k::NTuple{N,Int} k::NTuple{N,Int}
@ -388,11 +384,9 @@ end
outdims(l::MaxPool{N}, isize) where N = output_size(PoolDims(_paddims(isize, (l.k..., 1, 1)), l.k; stride = l.stride, padding = l.pad)) outdims(l::MaxPool{N}, isize) where N = output_size(PoolDims(_paddims(isize, (l.k..., 1, 1)), l.k; stride = l.stride, padding = l.pad))
""" """
MeanPool(k) MeanPool(k; pad = 0, stride = k)
Mean pooling layer. `k` stands for the size of the window for each dimension of the input. Mean pooling layer. `k` is the size of the window for each dimension of the input.
Takes the keyword arguments `pad` and `stride`.
""" """
struct MeanPool{N,M} struct MeanPool{N,M}
k::NTuple{N,Int} k::NTuple{N,Int}

View File

@ -10,14 +10,14 @@ _dropout_shape(s, dims) = tuple((i ∉ dims ? 1 : si for (i, si) ∈ enumerate(s
_dropout_kernel(y::T, p, q) where {T} = y > p ? T(1 / q) : T(0) _dropout_kernel(y::T, p, q) where {T} = y > p ? T(1 / q) : T(0)
""" """
dropout(p, dims = :) dropout(x, p; dims = :)
Dropout function. For each input, either sets that input to `0` (with probability The dropout function. For each input, either sets that input to `0` (with probability
`p`) or scales it by `1/(1-p)`. The `dims` argument is to specify the unbroadcasted `p`) or scales it by `1 / (1 - p)`. `dims` specifies the unbroadcasted dimensions,
dimensions, i.e. `dims=1` does dropout along columns and `dims=2` along rows. This is e.g. `dims=1` applies dropout along columns and `dims=2` along rows.
used as a regularisation, i.e. it reduces overfitting during training. This is used as a regularisation, i.e. it reduces overfitting during training.
See also [`Dropout`](@ref). See also the [`Dropout`](@ref) layer.
""" """
dropout(x, p; dims = :) = x dropout(x, p; dims = :) = x
@ -32,7 +32,7 @@ end
A Dropout layer. In the forward pass, applies the [`dropout`](@ref) function on the input. A Dropout layer. In the forward pass, applies the [`dropout`](@ref) function on the input.
Does nothing to the input once [`testmode!`](@ref) is true. Does nothing to the input once [`Flux.testmode!`](@ref) is `true`.
""" """
mutable struct Dropout{F,D} mutable struct Dropout{F,D}
p::F p::F
@ -64,9 +64,9 @@ end
""" """
AlphaDropout(p) AlphaDropout(p)
A dropout layer. It is used in Self-Normalizing Neural Networks. A dropout layer. It is used in
(https://papers.nips.cc/paper/6698-self-normalizing-neural-networks.pdf) [Self-Normalizing Neural Networks](https://papers.nips.cc/paper/6698-self-normalizing-neural-networks.pdf).
The AlphaDropout layer ensures that mean and variance of activations remains the same as before. The AlphaDropout layer ensures that mean and variance of activations remains the same as before.
Does nothing to the input once [`testmode!`](@ref) is true. Does nothing to the input once [`testmode!`](@ref) is true.
@ -100,8 +100,8 @@ testmode!(m::AlphaDropout, mode = true) =
LayerNorm(h::Integer) LayerNorm(h::Integer)
A [normalisation layer](https://arxiv.org/pdf/1607.06450.pdf) designed to be A [normalisation layer](https://arxiv.org/pdf/1607.06450.pdf) designed to be
used with recurrent hidden states of size `h`. Normalises the mean/stddev of used with recurrent hidden states of size `h`. Normalises the mean and standard
each input before applying a per-neuron gain/bias. deviation of each input before applying a per-neuron gain/bias.
""" """
struct LayerNorm{T} struct LayerNorm{T}
diag::Diagonal{T} diag::Diagonal{T}
@ -139,7 +139,7 @@ Use [`testmode!`](@ref) during inference.
See [Batch Normalization: Accelerating Deep Network Training by Reducing See [Batch Normalization: Accelerating Deep Network Training by Reducing
Internal Covariate Shift](https://arxiv.org/pdf/1502.03167.pdf). Internal Covariate Shift](https://arxiv.org/pdf/1502.03167.pdf).
Example: # Examples
```julia ```julia
m = Chain( m = Chain(
Dense(28^2, 64), Dense(28^2, 64),
@ -234,7 +234,7 @@ Use [`testmode!`](@ref) during inference.
See [Instance Normalization: The Missing Ingredient for Fast Stylization](https://arxiv.org/abs/1607.08022). See [Instance Normalization: The Missing Ingredient for Fast Stylization](https://arxiv.org/abs/1607.08022).
Example: # Examples
```julia ```julia
m = Chain( m = Chain(
Dense(28^2, 64), Dense(28^2, 64),
@ -316,28 +316,27 @@ function Base.show(io::IO, l::InstanceNorm)
end end
""" """
Group Normalization. GroupNorm(chs::Integer, G::Integer, λ = identity;
This layer can outperform Batch-Normalization and Instance-Normalization. initβ = (i) -> zeros(Float32, i), initγ = (i) -> ones(Float32, i),
ϵ = 1f-5, momentum = 0.1f0)
GroupNorm(chs::Integer, G::Integer, λ = identity; [Group Normalization](https://arxiv.org/pdf/1803.08494.pdf) layer.
initβ = (i) -> zeros(Float32, i), initγ = (i) -> ones(Float32, i), This layer can outperform Batch Normalization and Instance Normalization.
ϵ = 1f-5, momentum = 0.1f0)
``chs`` is the number of channels, the channel dimension of your input. `chs` is the number of channels, the channel dimension of your input.
For an array of N dimensions, the (N-1)th index is the channel dimension. For an array of N dimensions, the `N-1`th index is the channel dimension.
``G`` is the number of groups along which the statistics would be computed. `G` is the number of groups along which the statistics are computed.
The number of channels must be an integer multiple of the number of groups. The number of channels must be an integer multiple of the number of groups.
Use [`testmode!`](@ref) during inference. Use [`testmode!`](@ref) during inference.
Example: # Examples
``` ```julia
m = Chain(Conv((3,3), 1=>32, leakyrelu;pad = 1), m = Chain(Conv((3,3), 1=>32, leakyrelu;pad = 1),
GroupNorm(32,16)) # 32 channels, 16 groups (G = 16), thus 2 channels per group used GroupNorm(32,16))
# 32 channels, 16 groups (G = 16), thus 2 channels per group used
``` ```
Link : https://arxiv.org/pdf/1803.08494.pdf
""" """
mutable struct GroupNorm{F,V,W,N,T} mutable struct GroupNorm{F,V,W,N,T}
G::T # number of groups G::T # number of groups

View File

@ -12,7 +12,7 @@ in the background. `cell` should be a model of the form:
h, y = cell(h, x...) h, y = cell(h, x...)
For example, here's a recurrent network that keeps a running total of its inputs. For example, here's a recurrent network that keeps a running total of its inputs:
```julia ```julia
accum(h, x) = (h+x, x) accum(h, x) = (h+x, x)
@ -135,8 +135,8 @@ Base.show(io::IO, l::LSTMCell) =
""" """
LSTM(in::Integer, out::Integer) LSTM(in::Integer, out::Integer)
Long Short Term Memory recurrent layer. Behaves like an RNN but generally [Long Short Term Memory](https://www.researchgate.net/publication/13853244_Long_Short-term_Memory)
exhibits a longer memory span over sequences. recurrent layer. Behaves like an RNN but generally exhibits a longer memory span over sequences.
See [this article](https://colah.github.io/posts/2015-08-Understanding-LSTMs/) See [this article](https://colah.github.io/posts/2015-08-Understanding-LSTMs/)
for a good overview of the internals. for a good overview of the internals.
@ -176,8 +176,8 @@ Base.show(io::IO, l::GRUCell) =
""" """
GRU(in::Integer, out::Integer) GRU(in::Integer, out::Integer)
Gated Recurrent Unit layer. Behaves like an RNN but generally [Gated Recurrent Unit](https://arxiv.org/abs/1406.1078) layer. Behaves like an
exhibits a longer memory span over sequences. RNN but generally exhibits a longer memory span over sequences.
See [this article](https://colah.github.io/posts/2015-08-Understanding-LSTMs/) See [this article](https://colah.github.io/posts/2015-08-Understanding-LSTMs/)
for a good overview of the internals. for a good overview of the internals.

View File

@ -73,7 +73,7 @@ computed as `-sum(y .* log.(ŷ) .* weight) / size(y, 2)`.
`weight` can be `Nothing`, a `Number` or an `AbstractVector`. `weight` can be `Nothing`, a `Number` or an `AbstractVector`.
`weight=nothing` acts like `weight=1` but is faster. `weight=nothing` acts like `weight=1` but is faster.
See also [`logitcrossentropy`](@ref), [`binarycrossentropy`](@ref). See also: [`Flux.logitcrossentropy`](@ref), [`Flux.binarycrossentropy`](@ref), [`Flux.logitbinarycrossentropy`](@ref)
# Examples # Examples
```jldoctest ```jldoctest
@ -86,13 +86,13 @@ crossentropy(ŷ::AbstractVecOrMat, y::AbstractVecOrMat; weight=nothing) = _cros
""" """
logitcrossentropy(, y; weight = 1) logitcrossentropy(, y; weight = 1)
Return the crossentropy computed after a [`logsoftmax`](@ref) operation; Return the crossentropy computed after a [`Flux.logsoftmax`](@ref) operation;
computed as `-sum(y .* logsoftmax(ŷ) .* weight) / size(y, 2)`. computed as `-sum(y .* logsoftmax(ŷ) .* weight) / size(y, 2)`.
`logitcrossentropy(ŷ, y)` is mathematically equivalent to `logitcrossentropy(ŷ, y)` is mathematically equivalent to
[`crossentropy(softmax(log(ŷ)), y)`](@ref) but it is more numerically stable. [`Flux.crossentropy(softmax(log(ŷ)), y)`](@ref) but it is more numerically stable.
See also [`crossentropy`](@ref), [`binarycrossentropy`](@ref). See also: [`Flux.crossentropy`](@ref), [`Flux.binarycrossentropy`](@ref), [`Flux.logitbinarycrossentropy`](@ref)
# Examples # Examples
```jldoctest ```jldoctest
@ -107,9 +107,20 @@ end
""" """
binarycrossentropy(, y; ϵ=eps()) binarycrossentropy(, y; ϵ=eps())
Return `-y*log(ŷ + ϵ) - (1-y)*log(1-ŷ + ϵ)`. The ϵ term provides numerical stability. Return ``-y*\\log( + ϵ) - (1-y)*\\log(1- + ϵ)``. The `ϵ` term provides numerical stability.
Typically, the prediction `` is given by the output of a [`sigmoid`](@ref) activation. Typically, the prediction `` is given by the output of a [`sigmoid`](@ref) activation.
See also: [`Flux.crossentropy`](@ref), [`Flux.logitcrossentropy`](@ref), [`Flux.logitbinarycrossentropy`](@ref)
# Examples
```jldoctest
julia> Flux.binarycrossentropy.(σ.([-1.1491, 0.8619, 0.3127]), [1, 1, 0])
3-element Array{Float64,1}:
1.424397097347566
0.35231664672364077
0.8616703662235441
```
""" """
binarycrossentropy(, y; ϵ=eps()) = -y*log( + ϵ) - (1 - y)*log(1 - + ϵ) binarycrossentropy(, y; ϵ=eps()) = -y*log( + ϵ) - (1 - y)*log(1 - + ϵ)
@ -119,10 +130,19 @@ CuArrays.@cufunc binarycrossentropy(ŷ, y; ϵ=eps(ŷ)) = -y*log(ŷ + ϵ) - (1
""" """
logitbinarycrossentropy(ŷ, y) logitbinarycrossentropy(ŷ, y)
`logitbinarycrossentropy(ŷ, y)` is mathematically equivalent to `binarycrossentropy(σ(ŷ), y)` `logitbinarycrossentropy(ŷ, y)` is mathematically equivalent to
but it is more numerically stable. [`Flux.binarycrossentropy(σ(log(ŷ)), y)`](@ref) but it is more numerically stable.
See also [`binarycrossentropy`](@ref), [`sigmoid`](@ref), [`logsigmoid`](@ref). See also: [`Flux.crossentropy`](@ref), [`Flux.logitcrossentropy`](@ref), [`Flux.binarycrossentropy`](@ref)
# Examples
```jldoctest
julia> Flux.logitbinarycrossentropy.([-1.1491, 0.8619, 0.3127], [1, 1, 0])
3-element Array{Float64,1}:
1.4243970973475661
0.35231664672364094
0.8616703662235443
```
""" """
logitbinarycrossentropy(ŷ, y) = (1 - y)*ŷ - logσ() logitbinarycrossentropy(ŷ, y) = (1 - y)*ŷ - logσ()
@ -132,26 +152,27 @@ CuArrays.@cufunc logitbinarycrossentropy(ŷ, y) = (1 - y)*ŷ - logσ(ŷ)
""" """
normalise(x; dims=1) normalise(x; dims=1)
Normalises `x` to mean 0 and standard deviation 1, across the dimensions given by `dims`. Defaults to normalising over columns. Normalise `x` to mean 0 and standard deviation 1 across the dimensions given by `dims`.
Defaults to normalising over columns.
```julia-repl ```jldoctest
julia> a = reshape(collect(1:9), 3, 3) julia> a = reshape(collect(1:9), 3, 3)
3×3 Array{Int64,2}: 3×3 Array{Int64,2}:
1 4 7 1 4 7
2 5 8 2 5 8
3 6 9 3 6 9
julia> normalise(a) julia> Flux.normalise(a)
3×3 Array{Float64,2}: 3×3 Array{Float64,2}:
-1.22474 -1.22474 -1.22474 -1.22474 -1.22474 -1.22474
0.0 0.0 0.0 0.0 0.0 0.0
1.22474 1.22474 1.22474 1.22474 1.22474 1.22474
julia> normalise(a, dims=2) julia> Flux.normalise(a, dims=2)
3×3 Array{Float64,2}: 3×3 Array{Float64,2}:
-1.22474 0.0 1.22474 -1.22474 0.0 1.22474
-1.22474 0.0 1.22474 -1.22474 0.0 1.22474
-1.22474 0.0 1.22474 -1.22474 0.0 1.22474
``` ```
""" """
function normalise(x::AbstractArray; dims=1) function normalise(x::AbstractArray; dims=1)
@ -191,7 +212,7 @@ Measures the loss given the prediction `ŷ` and true labels `y` (containing 1 o
Returns `sum((max.(0, 1 .- ŷ .* y))) / size(y, 2)` Returns `sum((max.(0, 1 .- ŷ .* y))) / size(y, 2)`
[Hinge Loss](https://en.wikipedia.org/wiki/Hinge_loss) [Hinge Loss](https://en.wikipedia.org/wiki/Hinge_loss)
See also [`squared_hinge`](@ref). See also: [`squared_hinge`](@ref)
""" """
hinge(, y) = sum(max.(0, 1 .- .* y)) * 1 // size(y, 2) hinge(, y) = sum(max.(0, 1 .- .* y)) * 1 // size(y, 2)
@ -201,7 +222,7 @@ hinge(ŷ, y) = sum(max.(0, 1 .- ŷ .* y)) * 1 // size(y, 2)
Computes squared hinge loss given the prediction `` and true labels `y` (conatining 1 or -1). Computes squared hinge loss given the prediction `` and true labels `y` (conatining 1 or -1).
Returns `sum((max.(0, 1 .- ŷ .* y)).^2) / size(y, 2)` Returns `sum((max.(0, 1 .- ŷ .* y)).^2) / size(y, 2)`
See also [`hinge`](@ref). See also: [`hinge`](@ref)
""" """
squared_hinge(, y) = sum((max.(0, 1 .- .* y)).^2) * 1 // size(y, 2) squared_hinge(, y) = sum((max.(0, 1 .- .* y)).^2) * 1 // size(y, 2)

View File

@ -45,22 +45,20 @@ cudaconvert(x::OneHotMatrix{<:CuArray}) = OneHotMatrix(x.height, cudaconvert(x.d
""" """
onehot(l, labels[, unk]) onehot(l, labels[, unk])
Create an [`OneHotVector`](@ref) wtih `l`-th element be `true` based on possible `labels` set. Create a [`OneHotVector`](@ref) with its `l`-th element `true` based on
If `unk` is given, it retruns `onehot(unk, labels)` if the input label `l` is not find in `labels`; otherwise possible `labels` set.
it will error. If `unk` is given, return `onehot(unk, labels)` if the input label `l` is not found
in `labels`; otherwise it will error.
## Examples
# Examples
```jldoctest ```jldoctest
julia> using Flux: onehot julia> Flux.onehot(:b, [:a, :b, :c])
julia> onehot(:b, [:a, :b, :c])
3-element Flux.OneHotVector: 3-element Flux.OneHotVector:
0 0
1 1
0 0
julia> onehot(:c, [:a, :b, :c]) julia> Flux.onehot(:c, [:a, :b, :c])
3-element Flux.OneHotVector: 3-element Flux.OneHotVector:
0 0
0 0
@ -85,12 +83,9 @@ end
Create an [`OneHotMatrix`](@ref) with a batch of labels based on possible `labels` set, returns the Create an [`OneHotMatrix`](@ref) with a batch of labels based on possible `labels` set, returns the
`onehot(unk, labels)` if given labels `ls` is not found in set `labels`. `onehot(unk, labels)` if given labels `ls` is not found in set `labels`.
## Examples # Examples
```jldoctest ```jldoctest
julia> using Flux: onehotbatch julia> Flux.onehotbatch([:b, :a, :b], [:a, :b, :c])
julia> onehotbatch([:b, :a, :b], [:a, :b, :c])
3×3 Flux.OneHotMatrix{Array{Flux.OneHotVector,1}}: 3×3 Flux.OneHotMatrix{Array{Flux.OneHotVector,1}}:
0 1 0 0 1 0
1 0 1 1 0 1
@ -107,13 +102,12 @@ Base.argmax(xs::OneHotVector) = xs.ix
Inverse operations of [`onehot`](@ref). Inverse operations of [`onehot`](@ref).
# Examples
```jldoctest ```jldoctest
julia> using Flux: onecold julia> Flux.onecold([true, false, false], [:a, :b, :c])
julia> onecold([true, false, false], [:a, :b, :c])
:a :a
julia> onecold([0.3, 0.2, 0.5], [:a, :b, :c]) julia> Flux.onecold([0.3, 0.2, 0.5], [:a, :b, :c])
:c :c
``` ```
""" """

View File

@ -6,19 +6,20 @@ const ϵ = 1e-8
# TODO: should use weak refs # TODO: should use weak refs
""" """
Descent(η) Descent(η = 0.1)
Classic gradient descent optimiser with learning rate `η`. Classic gradient descent optimiser with learning rate `η`.
For each parameter `p` and its gradient `δp`, this runs `p -= η*δp` For each parameter `p` and its gradient `δp`, this runs `p -= η*δp`
## Parameters # Parameters
- Learning Rate (η): The amount by which the gradients are discounted before updating the weights. Defaults to `0.1`. - Learning rate (`η`): Amount by which the gradients are discounted before updating
the weights.
## Example # Examples
```julia-repl ```julia
opt = Descent() # uses default η (0.1) opt = Descent()
opt = Descent(0.3) # use provided η opt = Descent(0.3)
ps = params(model) ps = params(model)
@ -40,17 +41,19 @@ function apply!(o::Descent, x, Δ)
end end
""" """
Momentum(η, ρ) Momentum(η = 0.01, ρ = 0.9)
Gradient descent with learning rate `η` and momentum `ρ`. Gradient descent optimizer with learning rate `η` and momentum `ρ`.
## Parameters # Parameters
- Learning Rate (`η`): Amount by which gradients are discounted before updating the weights. Defaults to `0.01`. - Learning rate (`η`): Amount by which gradients are discounted before updating the
- Momentum (`ρ`): Parameter that accelerates descent in the relevant direction and dampens oscillations. Defaults to `0.9`. weights.
- Momentum (`ρ`): Controls the acceleration of gradient descent in the relevant direction
and therefore the dampening of oscillations.
## Examples # Examples
```julia ```julia
opt = Momentum() # uses defaults of η = 0.01 and ρ = 0.9 opt = Momentum()
opt = Momentum(0.01, 0.99) opt = Momentum(0.01, 0.99)
``` ```
@ -71,17 +74,18 @@ function apply!(o::Momentum, x, Δ)
end end
""" """
Nesterov(η, ρ) Nesterov(η = 0.001, ρ = 0.9)
Gradient descent with learning rate `η` and Nesterov momentum `ρ`. Gradient descent optimizer with learning rate `η` and Nesterov momentum `ρ`.
## Parameters # Parameters
- Learning Rate (η): Amount by which the gradients are dicsounted berfore updating the weights. Defaults to `0.001`. - Learning rate (`η`): Amount by which the gradients are discounted before updating the
- Nesterov Momentum (ρ): Parameters controlling the amount of nesterov momentum to be applied. Defaults to `0.9`. weights.
- Nesterov momentum (`ρ`): The amount of Nesterov momentum to be applied.
## Examples # Examples
```julia ```julia
opt = Nesterov() # uses defaults η = 0.001 and ρ = 0.9 opt = Nesterov()
opt = Nesterov(0.003, 0.95) opt = Nesterov(0.003, 0.95)
``` ```
@ -103,23 +107,23 @@ function apply!(o::Nesterov, x, Δ)
end end
""" """
RMSProp(η, ρ) RMSProp(η = 0.001, ρ = 0.9)
Implements the RMSProp algortihm. Often a good choice for recurrent networks. Parameters other than learning rate generally don't need tuning. Optimizer using the
[RMSProp](https://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf)
algorithm. Often a good choice for recurrent networks. Parameters other than learning rate
generally don't need tuning.
## Parameters # Parameters
- Learning Rate (η): Defaults to `0.001`. - Learning rate (`η`)
- Rho (ρ): Defaults to `0.9`. - Momentum (`ρ`)
## Examples # Examples
```julia ```julia
opt = RMSProp() # uses default η = 0.001 and ρ = 0.9 opt = RMSProp()
opt = RMSProp(0.002, 0.95) opt = RMSProp(0.002, 0.95)
``` ```
## References
[RMSProp](https://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf)
""" """
mutable struct RMSProp mutable struct RMSProp
eta::Float64 eta::Float64
@ -137,23 +141,21 @@ function apply!(o::RMSProp, x, Δ)
end end
""" """
ADAM(η, β::Tuple) ADAM(η = 0.001, β::Tuple = (0.9, 0.999))
Implements the ADAM optimiser. [ADAM](https://arxiv.org/abs/1412.6980v8) optimiser.
## Paramters # Parameters
- Learning Rate (`η`): Defaults to `0.001`. - Learning rate (`η`)
- Beta (`β::Tuple`): The first element refers to β1 and the second to β2. Defaults to `(0.9, 0.999)`. - Decay of momentums (`β::Tuple`): Exponential decay for the first (β1) and the
second (β2) momentum estimate.
## Examples
# Examples
```julia ```julia
opt = ADAM() # uses the default η = 0.001 and β = (0.9, 0.999) opt = ADAM()
opt = ADAM(0.001, (0.9, 0.8)) opt = ADAM(0.001, (0.9, 0.8))
``` ```
## References
[ADAM](https://arxiv.org/abs/1412.6980v8) optimiser.
""" """
mutable struct ADAM mutable struct ADAM
eta::Float64 eta::Float64
@ -174,24 +176,21 @@ function apply!(o::ADAM, x, Δ)
end end
""" """
RADAM(η, β::Tuple) RADAM(η = 0.001, β::Tuple = (0.9, 0.999))
Implements the rectified ADAM optimizer. [Rectified ADAM](https://arxiv.org/pdf/1908.03265v1.pdf) optimizer.
## Parameters # Parameters
- Learning Rate (η): Defaults to `0.001` - Learning rate (`η`)
- Beta (β::Tuple): The first element refers to β1 and the second to β2. Defaults to `(0.9, 0.999)`. - Decay of momentums (`β::Tuple`): Exponential decay for the first (β1) and the
second (β2) momentum estimate.
## Examples
# Examples
```julia ```julia
opt = RADAM() # uses the default η = 0.001 and β = (0.9, 0.999) opt = RADAM()
opt = RADAM(0.001, (0.9, 0.8)) opt = RADAM(0.001, (0.9, 0.8))
``` ```
## References
[RADAM](https://arxiv.org/pdf/1908.03265v1.pdf) optimiser (Rectified ADAM).
""" """
mutable struct RADAM mutable struct RADAM
eta::Float64 eta::Float64
@ -219,22 +218,21 @@ function apply!(o::RADAM, x, Δ)
end end
""" """
AdaMax(η, β::Tuple) AdaMax(η = 0.001, β::Tuple = (0.9, 0.999))
Variant of ADAM based on -norm. [AdaMax](https://arxiv.org/abs/1412.6980v9) is a variant of ADAM based on the -norm.
## Parameters # Parameters
- Learning Rate (η): Defaults to `0.001` - Learning rate (`η`)
- Beta (β::Tuple): The first element refers to β1 and the second to β2. Defaults to `(0.9, 0.999)`. - Decay of momentums (`β::Tuple`): Exponential decay for the first (β1) and the
second (β2) momentum estimate.
## Examples # Examples
```julia ```julia
opt = AdaMax() # uses default η and β opt = AdaMax()
opt = AdaMax(0.001, (0.9, 0.995)) opt = AdaMax(0.001, (0.9, 0.995))
``` ```
## References
[AdaMax](https://arxiv.org/abs/1412.6980v9) optimiser.
""" """
mutable struct AdaMax mutable struct AdaMax
eta::Float64 eta::Float64
@ -255,23 +253,21 @@ function apply!(o::AdaMax, x, Δ)
end end
""" """
ADAGrad(η) ADAGrad(η = 0.1)
Implements AdaGrad. It has parameter specific learning rates based on how frequently it is updated. [ADAGrad](http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf) optimizer. It has
parameter specific learning rates based on how frequently it is updated.
Parameters don't need tuning.
## Parameters # Parameters
- Learning Rate (η): Defaults to `0.1` - Learning rate (`η`)
## Examples # Examples
```julia ```julia
opt = ADAGrad() # uses default η = 0.1 opt = ADAGrad()
opt = ADAGrad(0.001) opt = ADAGrad(0.001)
``` ```
## References
[ADAGrad](http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf) optimiser.
Parameters don't need tuning.
""" """
mutable struct ADAGrad mutable struct ADAGrad
eta::Float64 eta::Float64
@ -288,21 +284,21 @@ function apply!(o::ADAGrad, x, Δ)
end end
""" """
ADADelta(ρ) ADADelta(ρ = 0.9)
Version of ADAGrad that adapts learning rate based on a window of past gradient updates. Parameters don't need tuning. [ADADelta](https://arxiv.org/abs/1212.5701) is a version of ADAGrad adapting its learning
rate based on a window of past gradient updates.
Parameters don't need tuning.
## Parameters # Parameters
- Rho (ρ): Factor by which gradient is decayed at each time step. Defaults to `0.9`. - Rho (`ρ`): Factor by which gradient is decayed at each time step.
## Examples # Examples
```julia ```julia
opt = ADADelta() # uses default ρ = 0.9 opt = ADADelta()
opt = ADADelta(0.89) opt = ADADelta(0.89)
``` ```
## References
[ADADelta](https://arxiv.org/abs/1212.5701) optimiser.
""" """
mutable struct ADADelta mutable struct ADADelta
rho::Float64 rho::Float64
@ -321,22 +317,22 @@ function apply!(o::ADADelta, x, Δ)
end end
""" """
AMSGrad(η, β::Tuple) AMSGrad(η = 0.001, β::Tuple = (0.9, 0.999))
Implements AMSGrad version of the ADAM optimiser. Parameters don't need tuning. The [AMSGrad](https://openreview.net/forum?id=ryQu7f-RZ) version of the ADAM
optimiser. Parameters don't need tuning.
## Parameters # Parameters
- Learning Rate (η): Defaults to `0.001`. - Learning Rate (`η`)
- Beta (β::Tuple): The first element refers to β1 and the second to β2. Defaults to `(0.9, 0.999)`. - Decay of momentums (`β::Tuple`): Exponential decay for the first (β1) and the
second (β2) momentum estimate.
## Examples # Examples
```julia ```julia
opt = AMSGrad() # uses default η and β opt = AMSGrad()
opt = AMSGrad(0.001, (0.89, 0.995)) opt = AMSGrad(0.001, (0.89, 0.995))
``` ```
## References
[AMSGrad](https://openreview.net/forum?id=ryQu7f-RZ) optimiser.
""" """
mutable struct AMSGrad mutable struct AMSGrad
eta::Float64 eta::Float64
@ -356,22 +352,22 @@ function apply!(o::AMSGrad, x, Δ)
end end
""" """
NADAM(η, β::Tuple) NADAM(η = 0.001, β::Tuple = (0.9, 0.999))
Nesterov variant of ADAM. Parameters don't need tuning. [NADAM](http://cs229.stanford.edu/proj2015/054_report.pdf) is a Nesterov variant of ADAM.
Parameters don't need tuning.
## Parameters # Parameters
- Learning Rate (η): Defaults to `0.001`. - Learning rate (`η`)
- Beta (β::Tuple): The first element refers to β1 and the second to β2. Defaults to `(0.9, 0.999)`. - Decay of momentums (`β::Tuple`): Exponential decay for the first (β1) and the
second (β2) momentum estimate.
## Examples # Examples
```julia ```julia
opt = NADAM() # uses default η and β opt = NADAM()
opt = NADAM(0.002, (0.89, 0.995)) opt = NADAM(0.002, (0.89, 0.995))
``` ```
## References
[NADAM](http://cs229.stanford.edu/proj2015/054_report.pdf) optimiser.
""" """
mutable struct NADAM mutable struct NADAM
eta::Float64 eta::Float64
@ -392,23 +388,23 @@ function apply!(o::NADAM, x, Δ)
end end
""" """
ADAMW(η, β::Tuple, decay) ADAMW(η = 0.001, β::Tuple = (0.9, 0.999), decay = 0)
Variant of ADAM defined by fixing weight decay regularization. [ADAMW](https://arxiv.org/abs/1711.05101) is a variant of ADAM fixing (as in repairing) its
weight decay regularization.
## Parameters # Parameters
- Learning Rate (η): Defaults to `0.001`. - Learning rate (`η`)
- Beta (β::Tuple): The first element refers to β1 and the second to β2. Defaults to (0.9, 0.999). - Decay of momentums (`β::Tuple`): Exponential decay for the first (β1) and the
- decay: Decay applied to weights during optimisation. Defaults to 0. second (β2) momentum estimate.
- `decay`: Decay applied to weights during optimisation.
## Examples # Examples
```julia ```julia
opt = ADAMW() # uses default η, β and decay opt = ADAMW()
opt = ADAMW(0.001, (0.89, 0.995), 0.1) opt = ADAMW(0.001, (0.89, 0.995), 0.1)
``` ```
## References
[ADAMW](https://arxiv.org/abs/1711.05101)
""" """
ADAMW(η = 0.001, β = (0.9, 0.999), decay = 0) = ADAMW(η = 0.001, β = (0.9, 0.999), decay = 0) =
Optimiser(ADAM(η, β), WeightDecay(decay)) Optimiser(ADAM(η, β), WeightDecay(decay))
@ -441,14 +437,13 @@ function apply!(o::Optimiser, x, Δ)
end end
""" """
InvDecay(γ) InvDecay(γ = 0.001)
Applies inverse time decay to an optimiser, i.e., the effective step size at iteration `n` is `eta / (1 + γ * n)` where `eta` is the initial step size. The wrapped optimiser's step size is not modified. Apply inverse time decay to an optimiser, so that the effective step size at
iteration `n` is `eta / (1 + γ * n)` where `eta` is the initial step size.
The wrapped optimiser's step size is not modified.
## Parameters # Examples
- gamma (γ): Defaults to `0.001`
## Example
```julia ```julia
Optimiser(InvDecay(..), Opt(..)) Optimiser(InvDecay(..), Opt(..))
``` ```
@ -469,20 +464,23 @@ function apply!(o::InvDecay, x, Δ)
end end
""" """
ExpDecay(eta, decay, decay_step, clip) ExpDecay(eta = 0.001, decay = 0.1, decay_step = 1000, clip = 1e-4)
Discount the learning rate `eta` by a multiplicative factor `decay` every `decay_step` till a minimum of `clip`. Discount the learning rate `eta` by the factor `decay` every `decay_step` steps till
a minimum of `clip`.
## Parameters # Parameters
- Learning Rate (eta): Defaults to `0.001`. - Learning rate (`eta`)
- decay: Factor by which the learning rate is discounted. Defaults to `0.1`. - `decay`: Factor by which the learning rate is discounted.
- decay_step: Schedules decay operations by setting number of steps between two decay operations. Defaults to `1000`. - `decay_step`: Schedule decay operations by setting number of steps between two decay
- clip: Minimum value of learning rate. Defaults to `1e-4`. operations.
- `clip`: Minimum value of learning rate.
## Example # Examples
To apply exponential decay to an optimiser: To apply exponential decay to an optimiser:
```julia ```julia
Optimiser(ExpDecay(..), Opt(..)) Optimiser(ExpDecay(..), Opt(..))
opt = Optimiser(ExpDecay(), ADAM()) opt = Optimiser(ExpDecay(), ADAM())
``` ```
""" """
@ -507,12 +505,12 @@ function apply!(o::ExpDecay, x, Δ)
end end
""" """
WeightDecay(wd) WeightDecay(wd = 0)
Decays the weight by `wd` Decay weights by `wd`.
## Parameters # Parameters
- weight decay (wd): 0 - Weight decay (`wd`)
""" """
mutable struct WeightDecay mutable struct WeightDecay
wd::Real wd::Real

View File

@ -43,9 +43,8 @@ struct StopException <: Exception end
Call `Flux.stop()` in a callback to indicate when a callback condition is met. Call `Flux.stop()` in a callback to indicate when a callback condition is met.
This would trigger the train loop to stop and exit. This would trigger the train loop to stop and exit.
# Examples
```julia ```julia
# Example callback:
cb = function () cb = function ()
accuracy() > 0.9 && Flux.stop() accuracy() > 0.9 && Flux.stop()
end end
@ -65,12 +64,12 @@ In case datapoints `d` are of numeric array type, assumes no splatting is needed
and computes the gradient of `loss(d)`. and computes the gradient of `loss(d)`.
Takes a callback as keyword argument `cb`. For example, this will print "training" Takes a callback as keyword argument `cb`. For example, this will print "training"
every 10 seconds: every 10 seconds (using [`throttle`](@ref)):
train!(loss, params, data, opt, train!(loss, params, data, opt,
cb = throttle(() -> println("training"), 10)) cb = throttle(() -> println("training"), 10))
The callback can call `Flux.stop()` to interrupt the training loop. The callback can call [`Flux.stop()`](@ref) to interrupt the training loop.
Multiple optimisers and callbacks can be passed to `opt` and `cb` as arrays. Multiple optimisers and callbacks can be passed to `opt` and `cb` as arrays.
""" """
@ -106,11 +105,12 @@ end
Run `body` `N` times. Mainly useful for quickly doing multiple epochs of Run `body` `N` times. Mainly useful for quickly doing multiple epochs of
training in a REPL. training in a REPL.
```julia # Examples
julia> @epochs 2 println("hello") ```jldoctest
INFO: Epoch 1 julia> Flux.@epochs 2 println("hello")
[ Info: Epoch 1
hello hello
INFO: Epoch 2 [ Info: Epoch 2
hello hello
``` ```
""" """

View File

@ -125,8 +125,9 @@ unstack(xs, dim) = [copy(selectdim(xs, dim, i)) for i in 1:size(xs, dim)]
Split `xs` into `n` parts. Split `xs` into `n` parts.
```julia # Examples
julia> chunk(1:10, 3) ```jldoctest
julia> Flux.chunk(1:10, 3)
3-element Array{Array{Int64,1},1}: 3-element Array{Array{Int64,1},1}:
[1, 2, 3, 4] [1, 2, 3, 4]
[5, 6, 7, 8] [5, 6, 7, 8]
@ -142,11 +143,12 @@ batchindex(xs, i) = (reverse(Base.tail(reverse(axes(xs))))..., i)
Count the number of times that each element of `xs` appears. Count the number of times that each element of `xs` appears.
```julia # Examples
julia> frequencies(['a','b','b']) ```jldoctest
julia> Flux.frequencies(['a','b','b'])
Dict{Char,Int64} with 2 entries: Dict{Char,Int64} with 2 entries:
'b' => 2
'a' => 1 'a' => 1
'b' => 2
``` ```
""" """
function frequencies(xs) function frequencies(xs)
@ -166,8 +168,9 @@ squeezebatch(x) = reshape(x, head(size(x)))
Batch the arrays in `xs` into a single array. Batch the arrays in `xs` into a single array.
```julia # Examples
julia> batch([[1,2,3],[4,5,6]]) ```jldoctest
julia> Flux.batch([[1,2,3],[4,5,6]])
3×2 Array{Int64,2}: 3×2 Array{Int64,2}:
1 4 1 4
2 5 2 5
@ -211,8 +214,9 @@ Base.rpad(v::AbstractVector, n::Integer, p) = [v; fill(p, max(n - length(v), 0))
Take a list of `N` sequences, and turn them into a single sequence where each Take a list of `N` sequences, and turn them into a single sequence where each
item is a batch of `N`. Short sequences will be padded by `pad`. item is a batch of `N`. Short sequences will be padded by `pad`.
```julia # Examples
julia> batchseq([[1, 2, 3], [4, 5]], 0) ```jldoctest
julia> Flux.batchseq([[1, 2, 3], [4, 5]], 0)
3-element Array{Array{Int64,1},1}: 3-element Array{Array{Int64,1},1}:
[1, 4] [1, 4]
[2, 5] [2, 5]
@ -269,11 +273,15 @@ end
# Other # Other
""" """
Returns a function that when invoked, will only be triggered at most once throttle(f, timeout; leading=true, trailing=false)
during `timeout` seconds. Normally, the throttled function will run
as much as it can, without ever going more than once per `wait` duration; Return a function that when invoked, will only be triggered at most once
but if you'd like to disable the execution on the leading edge, pass during `timeout` seconds.
`leading=false`. To enable execution on the trailing edge, ditto.
Normally, the throttled function will run as much as it can, without ever
going more than once per `wait` duration; but if you'd like to disable the
execution on the leading edge, pass `leading=false`. To enable execution on
the trailing edge, pass `trailing=true`.
""" """
function throttle(f, timeout; leading=true, trailing=false) function throttle(f, timeout; leading=true, trailing=false)
cooldown = true cooldown = true