853: Improve docs r=CarloLucibello a=janEbert

If you disagree with any of the changes, please tell me what to reverse or fix.
I am unsure about the docstrings I added to `src/utils.jl` for `unsqueeze` and
the `[un]stack` functions, so please give those a more detailed look.
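
For reviewers who want a quick reminder of what those helpers do, here is a rough sketch (illustrative only, not part of the diff; the shapes reflect my reading of the current implementations):

```julia
using Flux

x = rand(2, 3)               # 2×3 matrix
Flux.unsqueeze(x, 2)         # 2×1×3 array: a new singleton dimension at position 2

xs = [rand(2) for _ in 1:4]  # four length-2 vectors
m = Flux.stack(xs, 2)        # 2×4 matrix: the vectors stacked along a new 2nd dimension
Flux.unstack(m, 2)           # splits the matrix back into a vector of four length-2 vectors
```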

Update the Documenter.jl version for new features, fix deprecation warnings in
`docs/make.jl`, and import Flux for all doctests.
Add missing docstrings to `src/utils.jl`, `src/layers/stateless.jl`, and `src/data/`;
add these and other missing functions to the Markdown docs.

Improve docstrings by...
   - fixing typos,
   - removing trailing and double whitespace,
   - using `jldoctest` blocks where applicable (see the sketch after this list),
   - fixing, updating, or correctly setting up existing doctests,
   - improving consistency (for example, always using "# Examples" instead
     of other variants),
   - removing empty lines between docstrings and functions,
   - putting keyword arguments into the docstring signature instead of merely
     mentioning them,
   - adding some missing but useful keyword arguments,
   - adding cross-references (`@ref`),
   - using LaTeX math where applicable, and
   - linking papers.
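
As an illustration of the `jldoctest` convention (a hypothetical docstring, not taken from this diff):

````julia
"""
    addone(x)

Return `x + 1`.

# Examples
```jldoctest
julia> addone(41)
42
```
"""
addone(x) = x + 1  # hypothetical helper, for illustration only
````

Since `docs/make.jl` now calls `DocMeta.setdocmeta!(Flux, :DocTestSetup, :(using Flux); recursive=true)`, doctests inside Flux docstrings do not need their own `using Flux` line.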

Debatable stuff that is untouched:
   - British/American s/z irregularities (e.g. "normalise" versus "normalize"),
     since most papers use the American spelling while the Flux source code
     was written with British spelling.
   - Names of normalization functions are capitalized
     ("Batch Normalization" instead of "batch normalization").
   - Default values in argument lists have spaces around the equals sign
     (`arg = x` instead of `arg=x`).

Co-authored-by: janEbert <janpublicebert@posteo.net>
bors[bot] 2020-04-06 13:47:42 +00:00 committed by GitHub
commit 7a32a703f0
26 changed files with 768 additions and 441 deletions

View File

@ -1,6 +1,8 @@
using Documenter, Flux, NNlib
DocMeta.setdocmeta!(Flux, :DocTestSetup, :(using Flux); recursive=true)
makedocs(modules=[Flux, NNlib],
doctest = VERSION >= v"1.4",
sitename = "Flux",
pages = ["Home" => "index.md",
"Building Models" =>
@ -19,11 +21,15 @@ makedocs(modules=[Flux, NNlib],
"GPU Support" => "gpu.md",
"Saving & Loading" => "saving.md",
"The Julia Ecosystem" => "ecosystem.md",
"Utility Functions" => "utilities.md",
"Performance Tips" => "performance.md",
"Datasets" => "datasets.md",
"Community" => "community.md"],
format = Documenter.HTML(assets = ["assets/flux.css"],
format = Documenter.HTML(
analytics = "UA-36890222-9",
prettyurls = haskey(ENV, "CI")))
assets = ["assets/flux.css"],
prettyurls = get(ENV, "CI", nothing) == "true"),
)
deploydocs(repo = "github.com/FluxML/Flux.jl.git",
target = "build",

View File

@ -31,6 +31,11 @@ julia> onecold([0.3, 0.2, 0.5], [:a, :b, :c])
:c
```
```@docs
Flux.onehot
Flux.onecold
```
## Batches
`onehotbatch` creates a batch (matrix) of one-hot vectors, and `onecold` treats matrices as batches.
@ -52,3 +57,7 @@ julia> onecold(ans, [:a, :b, :c])
```
Note that these operations returned `OneHotVector` and `OneHotMatrix` rather than `Array`s. `OneHotVector`s behave like normal vectors but avoid any unnecessary cost compared to using an integer index directly. For example, multiplying a matrix with a one-hot vector simply slices out the relevant row of the matrix under the hood.
```@docs
Flux.onehotbatch
```

20
docs/src/datasets.md Normal file
View File

@ -0,0 +1,20 @@
# Datasets
Flux includes several standard machine learning datasets.
```@docs
Flux.Data.Iris.features()
Flux.Data.Iris.labels()
Flux.Data.MNIST.images()
Flux.Data.MNIST.labels()
Flux.Data.FashionMNIST.images()
Flux.Data.FashionMNIST.labels()
Flux.Data.CMUDict.phones()
Flux.Data.CMUDict.symbols()
Flux.Data.CMUDict.rawdict()
Flux.Data.CMUDict.cmudict()
Flux.Data.Sentiment.train()
Flux.Data.Sentiment.test()
Flux.Data.Sentiment.dev()
```

View File

@ -220,7 +220,7 @@ Flux.@functor Affine
This enables a useful extra set of functionality for our `Affine` layer, such as [collecting its parameters](../training/optimisers.md) or [moving it to the GPU](../gpu.md).
For some more helpful tricks, including parameter freezing, please checkout the [advanced usage guide](advacned.md).
For some more helpful tricks, including parameter freezing, please checkout the [advanced usage guide](advanced.md).
## Utility functions
@ -240,5 +240,5 @@ Currently limited to the following layers:
- `MeanPool`
```@docs
outdims
Flux.outdims
```

View File

@ -32,6 +32,7 @@ RNN
LSTM
GRU
Flux.Recur
Flux.reset!
```
## Other General Purpose Layers
@ -49,20 +50,22 @@ SkipConnection
These layers don't affect the structure of the network but may improve training times or reduce overfitting.
```@docs
Flux.normalise
BatchNorm
Dropout
Flux.dropout
Dropout
AlphaDropout
LayerNorm
InstanceNorm
GroupNorm
```
### Testmode
Many normalisation layers behave differently under training and inference (testing). By default, Flux will automatically determine when a layer evaluation is part of training or inference. Still, depending on your use case, it may be helpful to manually specify when these layers should be treated as being trained or not. For this, Flux provides `testmode!`. When called on a model (e.g. a layer or chain of layers), this function will place the model into the mode specified.
Many normalisation layers behave differently under training and inference (testing). By default, Flux will automatically determine when a layer evaluation is part of training or inference. Still, depending on your use case, it may be helpful to manually specify when these layers should be treated as being trained or not. For this, Flux provides `Flux.testmode!`. When called on a model (e.g. a layer or chain of layers), this function will place the model into the mode specified.
```@docs
testmode!
Flux.testmode!
trainmode!
```

View File

@ -64,3 +64,7 @@ julia> activations(c, rand(10))
julia> sum(norm, ans)
2.1166067f0
```
```@docs
Flux.activations
```

View File

@ -52,6 +52,7 @@ Momentum
Nesterov
RMSProp
ADAM
RADAM
AdaMax
ADAGrad
ADADelta

View File

@ -32,6 +32,7 @@ Flux.train!(loss, ps, data, opt)
```
The objective will almost always be defined in terms of some *cost function* that measures the distance of the prediction `m(x)` from the target `y`. Flux has several of these built in, like `mse` for mean squared error or `crossentropy` for cross entropy loss, but you can calculate it however you want.
For a list of all built-in loss functions, check out the [layer reference](../models/layers.md).
At first glance it may seem strange that the model that we want to train is not part of the input arguments of `Flux.train!` too. However the target of the optimizer is not the model itself, but the objective function that represents the departure between modelled and observed data. In other words, the model is implicitly defined in the objective function, and there is no need to give it explicitly. Passing the objective function instead of the model and a cost function separately provides more flexibility, and the possibility of optimizing the calculations.
@ -94,6 +95,10 @@ julia> @epochs 2 Flux.train!(...)
# Train for two epochs
```
```@docs
Flux.@epochs
```
## Callbacks
`train!` takes an additional argument, `cb`, that's used for callbacks so that you can observe the training process. For example:

49
docs/src/utilities.md Normal file
View File

@ -0,0 +1,49 @@
# Utility Functions
Flux contains some utility functions for working with data; these functions
help create inputs for your models or batch your dataset.
Other functions can be used to initialize your layers or to regularly execute
callback functions.
## Working with Data
```@docs
Flux.unsqueeze
Flux.stack
Flux.unstack
Flux.chunk
Flux.frequencies
Flux.batch
Flux.batchseq
Base.rpad(v::AbstractVector, n::Integer, p)
```
## Layer Initialization
These are primarily useful if you are planning to write your own layers.
Flux initializes convolutional layers and recurrent cells with `glorot_uniform`
by default.
To change the default on an applicable layer, pass the desired function with the
`init` keyword. For example:
```jldoctest; setup = :(using Flux)
julia> conv = Conv((3, 3), 1 => 8, relu; init=Flux.glorot_normal)
Conv((3, 3), 1=>8, relu)
```
```@docs
Flux.glorot_uniform
Flux.glorot_normal
```
## Model Abstraction
```@docs
Flux.destructure
```
## Callback Helpers
```@docs
Flux.throttle
Flux.stop
```

View File

@ -24,18 +24,35 @@ function load()
end
end
"""
phones()
Return a `Vector` containing the phones used in the CMU Pronouncing Dictionary.
"""
function phones()
load()
Symbol.(first.(split.(split(read(deps("cmudict", "cmudict.phones"),String),
"\n", keepempty = false), "\t")))
end
"""
symbols()
Return a `Vector` containing the symbols used in the CMU Pronouncing Dictionary.
A symbol is a phone with optional auxiliary symbols, indicating for example the
amount of stress on the phone.
"""
function symbols()
load()
Symbol.(split(read(deps("cmudict", "cmudict.symbols"),String),
"\n", keepempty = false))
end
"""
rawdict()
Return the unfiltered CMU Pronouncing Dictionary.
"""
function rawdict()
load()
Dict(String(xs[1]) => Symbol.(xs[2:end]) for xs in
@ -44,6 +61,14 @@ end
validword(s) = isascii(s) && occursin(r"^[\w\-\.]+$", s)
"""
cmudict()
Return a filtered CMU Pronouncing Dictionary.
It is filtered so each word contains only ASCII characters and a combination of
word characters (as determined by the regex engine using `\\w`), '-' and '.'.
"""
cmudict() = filter(p -> validword(p.first), rawdict())
alphabet() = ['A':'Z'..., '0':'9'..., '_', '-', '.']

View File

@ -33,9 +33,10 @@ const TESTLABELS = joinpath(dir, "t10k-labels-idx1-ubyte")
Load the Fashion-MNIST images.
Each image is a 28×28 array of `Gray` colour values (see Colors.jl).
Each image is a 28×28 array of `Gray` colour values
(see [Colors.jl](https://github.com/JuliaGraphics/Colors.jl)).
Returns the 60,000 training images by default; pass `:test` to retreive the
Return the 60,000 training images by default; pass `:test` to retrieve the
10,000 test images.
"""
function images(set = :train)
@ -49,10 +50,10 @@ end
labels()
labels(:test)
Load the labels corresponding to each of the images returned from `images()`.
Load the labels corresponding to each of the images returned from [`images()`](@ref).
Each label is a number from 0-9.
Returns the 60,000 training labels by default; pass `:test` to retreive the
Return the 60,000 training labels by default; pass `:test` to retrieve the
10,000 test labels.
"""
function labels(set = :train)

View File

@ -4,11 +4,10 @@ Fisher's classic iris dataset.
Measurements from 3 different species of iris: setosa, versicolor and
virginica. There are 50 examples of each species.
There are 4 measurements for each example: sepal length, sepal width, petal
length and petal width. The measurements are in centimeters.
There are 4 measurements for each example: sepal length, sepal width,
petal length and petal width. The measurements are in centimeters.
The module retrieves the data from the [UCI Machine Learning Repository](https://archive.ics.uci.edu/ml/datasets/iris).
"""
module Iris
@ -33,9 +32,7 @@ end
Get the labels of the iris dataset, a 150 element array of strings listing the
species of each example.
```jldoctest
julia> using Flux
```jldoctest; setup = :(Flux.Data.Iris.load())
julia> labels = Flux.Data.Iris.labels();
julia> summary(labels)
@ -58,9 +55,7 @@ Get the features of the iris dataset. This is a 4x150 matrix of Float64
elements. It has a row for each feature (sepal length, sepal width,
petal length, petal width) and a column for each example.
```jldoctest
julia> using Flux
```jldoctest; setup = :(Flux.Data.Iris.load())
julia> features = Flux.Data.Iris.features();
julia> summary(features)

View File

@ -83,9 +83,10 @@ getfeatures(io::IO, index::Integer) = vec(getimage(io, index))
Load the MNIST images.
Each image is a 28×28 array of `Gray` colour values (see Colors.jl).
Each image is a 28×28 array of `Gray` colour values
(see [Colors.jl](https://github.com/JuliaGraphics/Colors.jl)).
Returns the 60,000 training images by default; pass `:test` to retreive the
Return the 60,000 training images by default; pass `:test` to retrieve the
10,000 test images.
"""
function images(set = :train)
@ -99,10 +100,10 @@ end
labels()
labels(:test)
Load the labels corresponding to each of the images returned from `images()`.
Load the labels corresponding to each of the images returned from [`images()`](@ref).
Each label is a number from 0-9.
Returns the 60,000 training labels by default; pass `:test` to retreive the
Return the 60,000 training labels by default; pass `:test` to retrieve the
10,000 test labels.
"""
function labels(set = :train)

View File

@ -1,3 +1,4 @@
"Stanford Sentiment Treebank dataset."
module Sentiment
using ZipFile
@ -39,8 +40,28 @@ function gettrees(name)
return parsetree.(ss)
end
"""
train()
Return the train split of the Stanford Sentiment Treebank.
The data is in [treebank](https://en.wikipedia.org/wiki/Treebank) format.
"""
train() = gettrees("train")
"""
test()
Return the test split of the Stanford Sentiment Treebank.
The data is in [treebank](https://en.wikipedia.org/wiki/Treebank) format.
"""
test() = gettrees("test")
"""
dev()
Return the dev split of the Stanford Sentiment Treebank.
The data is in [treebank](https://en.wikipedia.org/wiki/Treebank) format.
"""
dev() = gettrees("dev")
end

View File

@ -4,17 +4,23 @@
Chain multiple layers / functions together, so that they are called in sequence
on a given input.
```julia
m = Chain(x -> x^2, x -> x+1)
m(5) == 26
m = Chain(Dense(10, 5), Dense(5, 2))
x = rand(10)
m(x) == m[2](m[1](x))
```
`Chain` also supports indexing and slicing, e.g. `m[2]` or `m[1:end-1]`.
`m[1:3](x)` will calculate the output of the first three layers.
# Examples
```jldoctest
julia> m = Chain(x -> x^2, x -> x+1);
julia> m(5) == 26
true
julia> m = Chain(Dense(10, 5), Dense(5, 2));
julia> x = rand(10);
julia> m(x) == m[2](m[1](x))
true
```
"""
struct Chain{T<:Tuple}
layers::T
@ -60,6 +66,7 @@ outdims(c::Chain, isize) = foldl(∘, map(l -> (x -> outdims(l, x)), c.layers))(
# only slightly changed to better handle interaction with Zygote @dsweber2
"""
activations(c::Chain, input)
Calculate the forward results of each layers in Chain `c` with `input` as model input.
"""
function activations(c::Chain, input)
@ -78,22 +85,22 @@ extraChain(::Tuple{}, x) = ()
"""
Dense(in::Integer, out::Integer, σ = identity)
Creates a traditional `Dense` layer with parameters `W` and `b`.
Create a traditional `Dense` layer with parameters `W` and `b`.
y = σ.(W * x .+ b)
The input `x` must be a vector of length `in`, or a batch of vectors represented
as an `in × N` matrix. The out `y` will be a vector or batch of length `out`.
```julia
# Examples
```jldoctest; setup = :(using Random; Random.seed!(0))
julia> d = Dense(5, 2)
Dense(5, 2)
julia> d(rand(5))
Array{Float64,1}:
0.00257447
-0.00449443
```
2-element Array{Float32,1}:
-0.16210233
0.12311903```
"""
struct Dense{F,S,T}
W::S
@ -145,7 +152,7 @@ outdims(l::Dense, isize) = (size(l.W)[1],)
"""
Diagonal(in::Integer)
Creates an element-wise linear transformation layer with learnable
Create an element-wise linear transformation layer with learnable
vectors `α` and `β`:
y = α .* x .+ β
@ -176,18 +183,11 @@ outdims(l::Diagonal, isize) = (length(l.α),)
"""
Maxout(over)
`Maxout` is a neural network layer, which has a number of internal layers,
which all have the same input, and the maxout returns the elementwise maximium
of the internal layers' outputs.
The [Maxout](https://arxiv.org/pdf/1302.4389.pdf) layer has a number of
internal layers which all receive the same input. It returns the elementwise
maximum of the internal layers' outputs.
Maxout over linear dense layers satisfies the univeral approximation theorem.
Reference:
Ian J. Goodfellow, David Warde-Farley, Mehdi Mirza, Aaron Courville, and Yoshua Bengio.
2013. Maxout networks.
In Proceedings of the 30th International Conference on International Conference on Machine Learning - Volume 28 (ICML'13),
Sanjoy Dasgupta and David McAllester (Eds.), Vol. 28. JMLR.org III-1319-III-1327.
https://arxiv.org/pdf/1302.4389.pdf
"""
struct Maxout{FS<:Tuple}
over::FS
@ -196,17 +196,18 @@ end
"""
Maxout(f, n_alts)
Constructs a Maxout layer over `n_alts` instances of the layer given by `f`.
The function takes no arguement and should return some callable layer.
Conventionally this is a linear dense layer.
Construct a Maxout layer over `n_alts` instances of the layer given by `f`.
The function takes no arguments and should return some callable layer.
Conventionally, this is a linear dense layer.
For example the following example which
will construct a `Maxout` layer over 4 internal dense linear layers,
each identical in structure (784 inputs, 128 outputs).
# Examples
This constructs a `Maxout` layer over 4 internal dense linear layers, each
identical in structure (784 inputs, 128 outputs):
```julia
insize = 784
outsize = 128
Maxout(()->Dense(insize, outsize), 4)
insize = 784
outsize = 128
Maxout(()->Dense(insize, outsize), 4)
```
"""
function Maxout(f, n_alts)
@ -223,16 +224,18 @@ end
outdims(l::Maxout, isize) = outdims(first(l.over), isize)
"""
SkipConnection(layers, connection)
SkipConnection(layer, connection)
Creates a Skip Connection, of a layer or `Chain` of consecutive layers
plus a shortcut connection. The connection function will combine the result of the layers
with the original input, to give the final output.
Create a skip connection which consists of a layer or `Chain` of consecutive
layers and a shortcut connection linking the block's input to the output
through a user-supplied 2-argument callable. The first argument to the callable
will be propagated through the given `layer` while the second is the unchanged,
"skipped" input.
The simplest 'ResNet'-type connection is just `SkipConnection(layer, +)`,
The simplest "ResNet"-type connection is just `SkipConnection(layer, +)`,
and requires the output of the layers to be the same shape as the input.
Here is a more complicated example:
```
```julia
m = Conv((3,3), 4=>7, pad=(1,1))
x = ones(5,5,4,10);
size(m(x)) == (5, 5, 7, 10)

View File

@ -8,25 +8,26 @@ _convtransoutdims(isize, ksize, ssize, dsize, pad) = (isize .- 1).*ssize .+ 1 .+
expand(N, i::Tuple) = i
expand(N, i::Integer) = ntuple(_ -> i, N)
"""
Conv(size, in=>out)
Conv(size, in=>out, relu)
Conv(size, in => out, σ = identity; init = glorot_uniform,
stride = 1, pad = 0, dilation = 1)
Standard convolutional layer. `size` should be a tuple like `(2, 2)`.
`in` and `out` specify the number of input and output channels respectively.
Example: Applying Conv layer to a 1-channel input using a 2x2 window size,
giving us a 16-channel output. Output is activated with ReLU.
size = (2,2)
in = 1
out = 16
Conv((2, 2), 1=>16, relu)
Data should be stored in WHCN order (width, height, # channels, batch size).
In other words, a 100×100 RGB image would be a `100×100×3×1` array,
and a batch of 50 would be a `100×100×3×50` array.
Takes the keyword arguments `pad`, `stride` and `dilation`.
# Examples
Apply a `Conv` layer to a 1-channel input using a 2×2 window size, giving us a
16-channel output. Output is activated with ReLU.
```julia
size = (2,2)
in = 1
out = 16
Conv(size, in => out, relu)
```
"""
struct Conv{N,M,F,A,V}
σ::F
@ -76,8 +77,8 @@ end
"""
outdims(l::Conv, isize::Tuple)
Calculate the output dimensions given the input dimensions, `isize`.
Batch size and channel size are ignored as per `NNlib.jl`.
Calculate the output dimensions given the input dimensions `isize`.
Batch size and channel size are ignored as per [NNlib.jl](https://github.com/FluxML/NNlib.jl).
```julia
m = Conv((3, 3), 3 => 16)
@ -89,17 +90,15 @@ outdims(l::Conv, isize) =
output_size(DenseConvDims(_paddims(isize, size(l.weight)), size(l.weight); stride = l.stride, padding = l.pad, dilation = l.dilation))
"""
ConvTranspose(size, in=>out)
ConvTranspose(size, in=>out, relu)
ConvTranspose(size, in => out, σ = identity; init = glorot_uniform,
stride = 1, pad = 0, dilation = 1)
Standard convolutional transpose layer. `size` should be a tuple like `(2, 2)`.
`in` and `out` specify the number of input and output channels respectively.
Data should be stored in WHCN order (width, height, # channels, # batches).
Data should be stored in WHCN order (width, height, # channels, batch size).
In other words, a 100×100 RGB image would be a `100×100×3×1` array,
and a batch of 50 would be a `100×100×3×50` array.
Takes the keyword arguments `pad`, `stride` and `dilation`.
"""
struct ConvTranspose{N,M,F,A,V}
σ::F
@ -165,18 +164,16 @@ end
outdims(l::ConvTranspose{N}, isize) where N = _convtransoutdims(isize[1:2], size(l.weight)[1:N], l.stride, l.dilation, l.pad)
"""
DepthwiseConv(size, in=>out)
DepthwiseConv(size, in=>out, relu)
DepthwiseConv(size, in => out, σ = identity; init = glorot_uniform,
stride = 1, pad = 0, dilation = 1)
Depthwise convolutional layer. `size` should be a tuple like `(2, 2)`.
`in` and `out` specify the number of input and output channels respectively.
Note that `out` must be an integer multiple of `in`.
Data should be stored in WHCN order (width, height, # channels, # batches).
Data should be stored in WHCN order (width, height, # channels, batch size).
In other words, a 100×100 RGB image would be a `100×100×3×1` array,
and a batch of 50 would be a `100×100×3×50` array.
Takes the keyword arguments `pad`, `stride` and `dilation`.
"""
struct DepthwiseConv{N,M,F,A,V}
σ::F
@ -233,25 +230,26 @@ outdims(l::DepthwiseConv, isize) =
output_size(DepthwiseConvDims(_paddims(isize, (1, 1, size(l.weight)[end], 1)), size(l.weight); stride = l.stride, padding = l.pad, dilation = l.dilation))
"""
CrossCor(size, in=>out)
CrossCor(size, in=>out, relu)
CrossCor(size, in => out, σ = identity; init = glorot_uniform,
stride = 1, pad = 0, dilation = 1)
Standard cross convolutional layer. `size` should be a tuple like `(2, 2)`.
`in` and `out` specify the number of input and output channels respectively.
Example: Applying CrossCor layer to a 1-channel input using a 2x2 window size,
giving us a 16-channel output. Output is activated with ReLU.
size = (2,2)
in = 1
out = 16
CrossCor((2, 2), 1=>16, relu)
Data should be stored in WHCN order (width, height, # channels, # batches).
Data should be stored in WHCN order (width, height, # channels, batch size).
In other words, a 100×100 RGB image would be a `100×100×3×1` array,
and a batch of 50 would be a `100×100×3×50` array.
Takes the keyword arguments `pad`, `stride` and `dilation`.
# Examples
Apply a `CrossCor` layer to a 1-channel input using a 2×2 window size, giving us a
16-channel output. Output is activated with ReLU.
```julia
size = (2,2)
in = 1
out = 16
CrossCor((2, 2), 1=>16, relu)
```
"""
struct CrossCor{N,M,F,A,V}
σ::F
@ -357,11 +355,9 @@ function Base.show(io::IO, g::GlobalMeanPool)
end
"""
MaxPool(k)
MaxPool(k; pad = 0, stride = k)
Max pooling layer. `k` stands for the size of the window for each dimension of the input.
Takes the keyword arguments `pad` and `stride`.
Max pooling layer. `k` is the size of the window for each dimension of the input.
"""
struct MaxPool{N,M}
k::NTuple{N,Int}
@ -388,11 +384,9 @@ end
outdims(l::MaxPool{N}, isize) where N = output_size(PoolDims(_paddims(isize, (l.k..., 1, 1)), l.k; stride = l.stride, padding = l.pad))
"""
MeanPool(k)
MeanPool(k; pad = 0, stride = k)
Mean pooling layer. `k` stands for the size of the window for each dimension of the input.
Takes the keyword arguments `pad` and `stride`.
Mean pooling layer. `k` is the size of the window for each dimension of the input.
"""
struct MeanPool{N,M}
k::NTuple{N,Int}

View File

@ -10,14 +10,14 @@ _dropout_shape(s, dims) = tuple((i ∉ dims ? 1 : si for (i, si) ∈ enumerate(s
_dropout_kernel(y::T, p, q) where {T} = y > p ? T(1 / q) : T(0)
"""
dropout(p, dims = :)
dropout(x, p; dims = :)
Dropout function. For each input, either sets that input to `0` (with probability
`p`) or scales it by `1/(1-p)`. The `dims` argument is to specify the unbroadcasted
dimensions, i.e. `dims=1` does dropout along columns and `dims=2` along rows. This is
used as a regularisation, i.e. it reduces overfitting during training.
The dropout function. For each input, either sets that input to `0` (with probability
`p`) or scales it by `1 / (1 - p)`. `dims` specifies the unbroadcasted dimensions,
e.g. `dims=1` applies dropout along columns and `dims=2` along rows.
This is used as a regularisation, i.e. it reduces overfitting during training.
See also [`Dropout`](@ref).
See also the [`Dropout`](@ref) layer.
"""
dropout(x, p; dims = :) = x
@ -30,9 +30,9 @@ end
"""
Dropout(p, dims = :)
A Dropout layer. In the forward pass, applies the [`dropout`](@ref) function on the input.
Dropout layer. In the forward pass, apply the [`Flux.dropout`](@ref) function on the input.
Does nothing to the input once [`testmode!`](@ref) is true.
Does nothing to the input once [`Flux.testmode!`](@ref) is `true`.
"""
mutable struct Dropout{F,D}
p::F
@ -65,9 +65,10 @@ end
"""
AlphaDropout(p)
A dropout layer. It is used in Self-Normalizing Neural Networks.
(https://papers.nips.cc/paper/6698-self-normalizing-neural-networks.pdf)
The AlphaDropout layer ensures that mean and variance of activations remains the same as before.
A dropout layer. Used in
[Self-Normalizing Neural Networks](https://papers.nips.cc/paper/6698-self-normalizing-neural-networks.pdf).
The AlphaDropout layer ensures that mean and variance of activations
remain the same as before.
Does nothing to the input once [`testmode!`](@ref) is true.
"""
@ -100,8 +101,8 @@ testmode!(m::AlphaDropout, mode = true) =
LayerNorm(h::Integer)
A [normalisation layer](https://arxiv.org/pdf/1607.06450.pdf) designed to be
used with recurrent hidden states of size `h`. Normalises the mean/stddev of
each input before applying a per-neuron gain/bias.
used with recurrent hidden states of size `h`. Normalises the mean and standard
deviation of each input before applying a per-neuron gain/bias.
"""
struct LayerNorm{T}
diag::Diagonal{T}
@ -123,8 +124,8 @@ end
initβ = zeros, initγ = ones,
ϵ = 1e-8, momentum = .1)
Batch Normalization layer. The `channels` input should be the size of the
channel dimension in your data (see below).
[Batch Normalization](https://arxiv.org/pdf/1502.03167.pdf) layer.
`channels` should be the size of the channel dimension in your data (see below).
Given an array with `N` dimensions, call the `N-1`th the channel dimension. (For
a batch of feature vectors this is just the data dimension, for `WHCN` images
@ -136,10 +137,7 @@ per-channel `bias` and `scale` parameters).
Use [`testmode!`](@ref) during inference.
See [Batch Normalization: Accelerating Deep Network Training by Reducing
Internal Covariate Shift](https://arxiv.org/pdf/1502.03167.pdf).
Example:
# Examples
```julia
m = Chain(
Dense(28^2, 64),
@ -213,37 +211,6 @@ function Base.show(io::IO, l::BatchNorm)
print(io, ")")
end
"""
InstanceNorm(channels::Integer, σ = identity;
initβ = zeros, initγ = ones,
ϵ = 1e-8, momentum = .1)
Instance Normalization layer. The `channels` input should be the size of the
channel dimension in your data (see below).
Given an array with `N` dimensions, call the `N-1`th the channel dimension. (For
a batch of feature vectors this is just the data dimension, for `WHCN` images
it's the usual channel dimension.)
`InstanceNorm` computes the mean and variance for each each `W×H×1×1` slice and
shifts them to have a new mean and variance (corresponding to the learnable,
per-channel `bias` and `scale` parameters).
Use [`testmode!`](@ref) during inference.
See [Instance Normalization: The Missing Ingredient for Fast Stylization](https://arxiv.org/abs/1607.08022).
Example:
```julia
m = Chain(
Dense(28^2, 64),
InstanceNorm(64, relu),
Dense(64, 10),
InstanceNorm(10),
softmax)
```
"""
expand_inst = (x, as) -> reshape(repeat(x, outer=[1, as[length(as)]]), as...)
mutable struct InstanceNorm{F,V,W,N}
@ -258,6 +225,34 @@ mutable struct InstanceNorm{F,V,W,N}
end
# TODO: deprecate in v0.11
"""
InstanceNorm(channels::Integer, σ = identity;
initβ = zeros, initγ = ones,
ϵ = 1e-8, momentum = .1)
[Instance Normalization](https://arxiv.org/abs/1607.08022) layer.
`channels` should be the size of the channel dimension in your data (see below).
Given an array with `N` dimensions, call the `N-1`th the channel dimension. (For
a batch of feature vectors this is just the data dimension, for `WHCN` images
it's the usual channel dimension.)
`InstanceNorm` computes the mean and variance for each each `W×H×1×1` slice and
shifts them to have a new mean and variance (corresponding to the learnable,
per-channel `bias` and `scale` parameters).
Use [`testmode!`](@ref) during inference.
# Examples
```julia
m = Chain(
Dense(28^2, 64),
InstanceNorm(64, relu),
Dense(64, 10),
InstanceNorm(10),
softmax)
```
"""
InstanceNorm(λ, β, γ, μ, σ², ϵ, momentum) = InstanceNorm(λ, β, γ, μ, σ², ϵ, momentum, nothing)
InstanceNorm(chs::Integer, λ = identity;
@ -316,28 +311,27 @@ function Base.show(io::IO, l::InstanceNorm)
end
"""
Group Normalization.
This layer can outperform Batch-Normalization and Instance-Normalization.
GroupNorm(chs::Integer, G::Integer, λ = identity;
initβ = (i) -> zeros(Float32, i), initγ = (i) -> ones(Float32, i),
ϵ = 1f-5, momentum = 0.1f0)
``chs`` is the number of channels, the channel dimension of your input.
For an array of N dimensions, the (N-1)th index is the channel dimension.
[Group Normalization](https://arxiv.org/pdf/1803.08494.pdf) layer.
This layer can outperform Batch Normalization and Instance Normalization.
``G`` is the number of groups along which the statistics would be computed.
`chs` is the number of channels, the channel dimension of your input.
For an array of N dimensions, the `N-1`th index is the channel dimension.
`G` is the number of groups along which the statistics are computed.
The number of channels must be an integer multiple of the number of groups.
Use [`testmode!`](@ref) during inference.
Example:
```
# Examples
```julia
m = Chain(Conv((3,3), 1=>32, leakyrelu;pad = 1),
GroupNorm(32,16)) # 32 channels, 16 groups (G = 16), thus 2 channels per group used
GroupNorm(32,16))
# 32 channels, 16 groups (G = 16), thus 2 channels per group used
```
Link : https://arxiv.org/pdf/1803.08494.pdf
"""
mutable struct GroupNorm{F,V,W,N,T}
G::T # number of groups

View File

@ -12,10 +12,10 @@ in the background. `cell` should be a model of the form:
h, y = cell(h, x...)
For example, here's a recurrent network that keeps a running total of its inputs.
For example, here's a recurrent network that keeps a running total of its inputs:
```julia
accum(h, x) = (h+x, x)
accum(h, x) = (h + x, x)
rnn = Flux.Recur(accum, 0)
rnn(2) # 2
rnn(3) # 3
@ -47,9 +47,10 @@ Base.show(io::IO, m::Recur) = print(io, "Recur(", m.cell, ")")
Reset the hidden state of a recurrent layer back to its original value.
Assuming you have a `Recur` layer `rnn`, this is roughly equivalent to
rnn.state = hidden(rnn.cell)
Assuming you have a `Recur` layer `rnn`, this is roughly equivalent to:
```julia
rnn.state = hidden(rnn.cell)
```
"""
reset!(m::Recur) = (m.state = m.init)
reset!(m) = foreach(reset!, functor(m)[1])
@ -135,8 +136,8 @@ Base.show(io::IO, l::LSTMCell) =
"""
LSTM(in::Integer, out::Integer)
Long Short Term Memory recurrent layer. Behaves like an RNN but generally
exhibits a longer memory span over sequences.
[Long Short Term Memory](https://www.researchgate.net/publication/13853244_Long_Short-term_Memory)
recurrent layer. Behaves like an RNN but generally exhibits a longer memory span over sequences.
See [this article](https://colah.github.io/posts/2015-08-Understanding-LSTMs/)
for a good overview of the internals.
@ -176,8 +177,8 @@ Base.show(io::IO, l::GRUCell) =
"""
GRU(in::Integer, out::Integer)
Gated Recurrent Unit layer. Behaves like an RNN but generally
exhibits a longer memory span over sequences.
[Gated Recurrent Unit](https://arxiv.org/abs/1406.1078) layer. Behaves like an
RNN but generally exhibits a longer memory span over sequences.
See [this article](https://colah.github.io/posts/2015-08-Understanding-LSTMs/)
for a good overview of the internals.

View File

@ -2,7 +2,8 @@
"""
mae(ŷ, y)
Return the mean of absolute error `sum(abs.(ŷ .- y)) / length(y)`
Return the mean of absolute error; calculated as
`sum(abs.(ŷ .- y)) / length(y)`.
"""
mae(ŷ, y) = sum(abs.(ŷ .- y)) * 1 // length(y)
@ -10,7 +11,14 @@ mae(ŷ, y) = sum(abs.(ŷ .- y)) * 1 // length(y)
"""
mse(ŷ, y)
Return the mean squared error `sum((ŷ .- y).^2) / length(y)`.
Return the mean squared error between ŷ and y; calculated as
`sum((ŷ .- y).^2) / length(y)`.
# Examples
```jldoctest
julia> Flux.mse([0, 2], [1, 1])
1//1
```
"""
mse(ŷ, y) = sum((ŷ .- y).^2) * 1 // length(y)
@ -18,10 +26,11 @@ mse(ŷ, y) = sum((ŷ .- y).^2) * 1 // length(y)
"""
msle(ŷ, y; ϵ=eps(eltype(ŷ)))
Returns the mean of the squared logarithmic errors `sum((log.(ŷ .+ ϵ) .- log.(y .+ ϵ)).^2) / length(y)`.
Return the mean of the squared logarithmic errors; calculated as
`sum((log.(ŷ .+ ϵ) .- log.(y .+ ϵ)).^2) / length(y)`.
The `ϵ` term provides numerical stability.
This error penalizes an under-predicted estimate greater than an over-predicted estimate.
Penalizes an under-predicted estimate greater than an over-predicted estimate.
"""
msle(ŷ, y; ϵ=eps(eltype(ŷ))) = sum((log.(ŷ .+ ϵ) .- log.(y .+ ϵ)).^2) * 1 // length(y)
@ -30,13 +39,12 @@ msle(ŷ, y; ϵ=eps(eltype(ŷ))) = sum((log.(ŷ .+ ϵ) .- log.(y .+ ϵ)).^2) *
"""
huber_loss(ŷ, y; δ=1.0)
Computes the mean of the Huber loss given the prediction `ŷ` and true values `y`. By default, δ is set to 1.0.
Return the mean of the [Huber loss](https://en.wikipedia.org/wiki/Huber_loss)
given the prediction `ŷ` and true values `y`.
| 0.5*|ŷ - y|, for |ŷ - y| <= δ
Hubber loss = |
| δ*(|ŷ - y| - 0.5*δ), otherwise
[`Huber Loss`](https://en.wikipedia.org/wiki/Huber_loss).
| 0.5 * |ŷ - y|, for |ŷ - y| <= δ
Huber loss = |
| δ * (|ŷ - y| - 0.5 * δ), otherwise
"""
function huber_loss(ŷ, y; δ=eltype(ŷ)(1))
abs_error = abs.(ŷ .- y)
@ -58,22 +66,40 @@ function _crossentropy(ŷ::AbstractVecOrMat, y::AbstractVecOrMat, weight::Abstr
end
"""
crossentropy(ŷ, y; weight=1)
crossentropy(ŷ, y; weight = nothing)
Return the crossentropy computed as `-sum(y .* log.(ŷ) .* weight) / size(y, 2)`.
Return the cross entropy between the given probability distributions;
calculated as `-sum(y .* log.(ŷ) .* weight) / size(y, 2)`.
See also [`logitcrossentropy`](@ref), [`binarycrossentropy`](@ref).
`weight` can be `Nothing`, a `Number` or an `AbstractVector`.
`weight=nothing` acts like `weight=1` but is faster.
See also: [`Flux.logitcrossentropy`](@ref), [`Flux.binarycrossentropy`](@ref), [`Flux.logitbinarycrossentropy`](@ref)
# Examples
```jldoctest
julia> Flux.crossentropy(softmax([-1.1491, 0.8619, 0.3127]), [1, 1, 0])
3.085467254747739
```
"""
crossentropy(ŷ::AbstractVecOrMat, y::AbstractVecOrMat; weight=nothing) = _crossentropy(ŷ, y, weight)
"""
logitcrossentropy(ŷ, y; weight=1)
logitcrossentropy(ŷ, y; weight = 1)
Return the crossentropy computed after a [softmax](@ref) operation:
Return the crossentropy computed after a [`Flux.logsoftmax`](@ref) operation;
calculated as `-sum(y .* logsoftmax(ŷ) .* weight) / size(y, 2)`.
-sum(y .* logsoftmax(ŷ) .* weight) / size(y, 2)
`logitcrossentropy(ŷ, y)` is mathematically equivalent to
[`Flux.crossentropy(softmax(log(ŷ)), y)`](@ref) but it is more numerically stable.
See also [`crossentropy`](@ref), [`binarycrossentropy`](@ref).
See also: [`Flux.crossentropy`](@ref), [`Flux.binarycrossentropy`](@ref), [`Flux.logitbinarycrossentropy`](@ref)
# Examples
```jldoctest
julia> Flux.logitcrossentropy([-1.1491, 0.8619, 0.3127], [1, 1, 0])
3.085467254747738
```
"""
function logitcrossentropy(ŷ::AbstractVecOrMat, y::AbstractVecOrMat; weight = 1)
return -sum(y .* logsoftmax(ŷ) .* weight) * 1 // size(y, 2)
@ -82,9 +108,20 @@ end
"""
binarycrossentropy(ŷ, y; ϵ=eps(ŷ))
Return `-y*log(ŷ + ϵ) - (1-y)*log(1-ŷ + ϵ)`. The ϵ term provides numerical stability.
Return ``-y*\\log(ŷ + ϵ) - (1-y)*\\log(1-ŷ + ϵ)``. The `ϵ` term provides numerical stability.
Typically, the prediction `ŷ` is given by the output of a [`sigmoid`](@ref) activation.
See also: [`Flux.crossentropy`](@ref), [`Flux.logitcrossentropy`](@ref), [`Flux.logitbinarycrossentropy`](@ref)
# Examples
```jldoctest
julia> Flux.binarycrossentropy.(σ.([-1.1491, 0.8619, 0.3127]), [1, 1, 0])
3-element Array{Float64,1}:
1.424397097347566
0.35231664672364077
0.8616703662235441
```
"""
binarycrossentropy(ŷ, y; ϵ=eps(ŷ)) = -y*log(ŷ + ϵ) - (1 - y)*log(1 - ŷ + ϵ)
@ -94,10 +131,19 @@ CuArrays.@cufunc binarycrossentropy(ŷ, y; ϵ=eps(ŷ)) = -y*log(ŷ + ϵ) - (1
"""
logitbinarycrossentropy(ŷ, y)
`logitbinarycrossentropy(ŷ, y)` is mathematically equivalent to `binarycrossentropy(σ(ŷ), y)`
but it is more numerically stable.
`logitbinarycrossentropy(ŷ, y)` is mathematically equivalent to
[`Flux.binarycrossentropy(σ(log(ŷ)), y)`](@ref) but it is more numerically stable.
See also [`binarycrossentropy`](@ref), [`sigmoid`](@ref), [`logsigmoid`](@ref).
See also: [`Flux.crossentropy`](@ref), [`Flux.logitcrossentropy`](@ref), [`Flux.binarycrossentropy`](@ref)
# Examples
```jldoctest
julia> Flux.logitbinarycrossentropy.([-1.1491, 0.8619, 0.3127], [1, 1, 0])
3-element Array{Float64,1}:
1.4243970973475661
0.35231664672364094
0.8616703662235443
```
"""
logitbinarycrossentropy(ŷ, y) = (1 - y)*ŷ - logσ(ŷ)
@ -107,22 +153,23 @@ CuArrays.@cufunc logitbinarycrossentropy(ŷ, y) = (1 - y)*ŷ - logσ(ŷ)
"""
normalise(x; dims=1)
Normalises `x` to mean 0 and standard deviation 1, across the dimensions given by `dims`. Defaults to normalising over columns.
Normalise `x` to mean 0 and standard deviation 1 across the dimensions given by `dims`.
Defaults to normalising over columns.
```julia-repl
```jldoctest
julia> a = reshape(collect(1:9), 3, 3)
3×3 Array{Int64,2}:
1 4 7
2 5 8
3 6 9
julia> normalise(a)
julia> Flux.normalise(a)
3×3 Array{Float64,2}:
-1.22474 -1.22474 -1.22474
0.0 0.0 0.0
1.22474 1.22474 1.22474
julia> normalise(a, dims=2)
julia> Flux.normalise(a, dims=2)
3×3 Array{Float64,2}:
-1.22474 0.0 1.22474
-1.22474 0.0 1.22474
@ -138,10 +185,14 @@ end
"""
kldivergence(ŷ, y)
KLDivergence is a measure of how much one probability distribution is different from the other.
It is always non-negative and zero only when both the distributions are equal everywhere.
Return the
[Kullback-Leibler divergence](https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence)
between the given probability distributions.
[KL Divergence](https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence).
KL divergence is a measure of how much one probability distribution is different
from the other.
It is always non-negative and zero only when both the distributions are equal
everywhere.
"""
function kldivergence(ŷ, y)
entropy = sum(y .* log.(y)) * 1 //size(y,2)
@ -152,59 +203,60 @@ end
"""
poisson(ŷ, y)
Poisson loss function is a measure of how the predicted distribution diverges from the expected distribution.
Returns `sum(ŷ .- y .* log.(ŷ)) / size(y, 2)`
Return how much the predicted distribution `ŷ` diverges from the expected Poisson
distribution `y`; calculated as `sum(ŷ .- y .* log.(ŷ)) / size(y, 2)`.
[Poisson Loss](https://peltarion.com/knowledge-center/documentation/modeling-view/build-an-ai-model/loss-functions/poisson).
[More information.](https://peltarion.com/knowledge-center/documentation/modeling-view/build-an-ai-model/loss-functions/poisson).
"""
poisson(ŷ, y) = sum(ŷ .- y .* log.(ŷ)) * 1 // size(y,2)
"""
hinge(ŷ, y)
Measures the loss given the prediction `ŷ` and true labels `y` (containing 1 or -1).
Returns `sum((max.(0, 1 .- ŷ .* y))) / size(y, 2)`
Return the [hinge loss](https://en.wikipedia.org/wiki/Hinge_loss) given the
prediction `ŷ` and true labels `y` (containing 1 or -1); calculated as
`sum(max.(0, 1 .- ŷ .* y)) / size(y, 2)`.
[Hinge Loss](https://en.wikipedia.org/wiki/Hinge_loss)
See also [`squared_hinge`](@ref).
See also: [`squared_hinge`](@ref)
"""
hinge(ŷ, y) = sum(max.(0, 1 .- ŷ .* y)) * 1 // size(y, 2)
"""
squared_hinge(ŷ, y)
Computes squared hinge loss given the prediction `ŷ` and true labels `y` (conatining 1 or -1).
Returns `sum((max.(0, 1 .- ŷ .* y)).^2) / size(y, 2)`
Return the squared hinge loss given the prediction `ŷ` and true labels `y`
(containing 1 or -1); calculated as `sum((max.(0, 1 .- ŷ .* y)).^2) / size(y, 2)`.
See also [`hinge`](@ref).
See also: [`hinge`](@ref)
"""
squared_hinge(ŷ, y) = sum((max.(0, 1 .- ŷ .* y)).^2) * 1 // size(y, 2)
"""
dice_coeff_loss(ŷ, y; smooth=1)
Loss function used in Image Segmentation. Calculates loss based on dice coefficient. Similar to F1_score.
Returns `1 - 2*sum(|ŷ .* y| + smooth) / (sum(ŷ.^2) + sum(y.^2) + smooth)`
[V-Net: Fully Convolutional Neural Networks forVolumetric Medical Image Segmentation](https://arxiv.org/pdf/1606.04797v1.pdf)
Return a loss based on the dice coefficient.
Used in the [V-Net](https://arxiv.org/pdf/1606.04797v1.pdf) image segmentation
architecture.
Similar to the F1_score. Calculated as:
1 - 2*sum(|ŷ .* y| + smooth) / (sum(ŷ.^2) + sum(y.^2) + smooth)`
"""
dice_coeff_loss(ŷ, y; smooth=eltype(ŷ)(1.0)) = 1 - (2*sum(y .* ŷ) + smooth) / (sum(y.^2) + sum(ŷ.^2) + smooth)
"""
tversky_loss(ŷ, y; β=0.7)
Used with imbalanced data to give more weightage to False negatives.
Return the [Tversky loss](https://arxiv.org/pdf/1706.05721.pdf).
Used with imbalanced data to give more weight to false negatives.
Larger β weigh recall higher than precision (by placing more emphasis on false negatives)
Returns `1 - sum(|y .* ŷ| + 1) / (sum(y .* ŷ + β*(1 .- y) .* ŷ + (1 - β)*y .* (1 .- ŷ)) + 1)`
[Tversky loss function for image segmentation using 3D fully convolutional deep networks](https://arxiv.org/pdf/1706.05721.pdf)
Calculated as:
1 - sum(|y .* ŷ| + 1) / (sum(y .* ŷ + β*(1 .- y) .* ŷ + (1 - β)*y .* (1 .- ŷ)) + 1)
"""
tversky_loss(ŷ, y; β=eltype(ŷ)(0.7)) = 1 - (sum(y .* ŷ) + 1) / (sum(y .* ŷ + β*(1 .- y) .* ŷ + (1 - β)*y .* (1 .- ŷ)) + 1)
"""
flatten(x::AbstractArray)
Transforms (w,h,c,b)-shaped input into (w x h x c,b)-shaped output,
Transform (w, h, c, b)-shaped input into (w × h × c, b)-shaped output
by linearizing all values for each element in the batch.
"""
function flatten(x::AbstractArray)

View File

@ -45,22 +45,20 @@ cudaconvert(x::OneHotMatrix{<:CuArray}) = OneHotMatrix(x.height, cudaconvert(x.d
"""
onehot(l, labels[, unk])
Create an [`OneHotVector`](@ref) wtih `l`-th element be `true` based on possible `labels` set.
If `unk` is given, it retruns `onehot(unk, labels)` if the input label `l` is not find in `labels`; otherwise
it will error.
## Examples
Create a `OneHotVector` with its `l`-th element `true` based on the
possible set of `labels`.
If `unk` is given, return `onehot(unk, labels)` if the input label `l` is not found
in `labels`; otherwise it will error.
# Examples
```jldoctest
julia> using Flux: onehot
julia> onehot(:b, [:a, :b, :c])
julia> Flux.onehot(:b, [:a, :b, :c])
3-element Flux.OneHotVector:
0
1
0
julia> onehot(:c, [:a, :b, :c])
julia> Flux.onehot(:c, [:a, :b, :c])
3-element Flux.OneHotVector:
0
0
@ -82,15 +80,14 @@ end
"""
onehotbatch(ls, labels[, unk...])
Create an [`OneHotMatrix`](@ref) with a batch of labels based on possible `labels` set, returns the
`onehot(unk, labels)` if given labels `ls` is not found in set `labels`.
## Examples
Create a `OneHotMatrix` with a batch of labels based on the
possible set of `labels`.
If `unk` is given, return [`onehot(unk, labels)`](@ref) if one of the input
labels `ls` is not found in `labels`; otherwise it will error.
# Examples
```jldoctest
julia> using Flux: onehotbatch
julia> onehotbatch([:b, :a, :b], [:a, :b, :c])
julia> Flux.onehotbatch([:b, :a, :b], [:a, :b, :c])
3×3 Flux.OneHotMatrix{Array{Flux.OneHotVector,1}}:
0 1 0
1 0 1
@ -107,13 +104,12 @@ Base.argmax(xs::OneHotVector) = xs.ix
Inverse operations of [`onehot`](@ref).
# Examples
```jldoctest
julia> using Flux: onecold
julia> onecold([true, false, false], [:a, :b, :c])
julia> Flux.onecold([true, false, false], [:a, :b, :c])
:a
julia> onecold([0.3, 0.2, 0.5], [:a, :b, :c])
julia> Flux.onecold([0.3, 0.2, 0.5], [:a, :b, :c])
:c
```
"""

View File

@ -6,19 +6,20 @@ const ϵ = 1e-8
# TODO: should use weak refs
"""
Descent(η)
Descent(η = 0.1)
Classic gradient descent optimiser with learning rate `η`.
For each parameter `p` and its gradient `δp`, this runs `p -= η*δp`
## Parameters
- Learning Rate (η): The amount by which the gradients are discounted before updating the weights. Defaults to `0.1`.
# Parameters
- Learning rate (`η`): Amount by which gradients are discounted before updating
the weights.
## Example
```julia-repl
opt = Descent() # uses default η (0.1)
# Examples
```julia
opt = Descent()
opt = Descent(0.3) # use provided η
opt = Descent(0.3)
ps = params(model)
@ -40,17 +41,19 @@ function apply!(o::Descent, x, Δ)
end
"""
Momentum(η, ρ)
Momentum(η = 0.01, ρ = 0.9)
Gradient descent with learning rate `η` and momentum `ρ`.
Gradient descent optimizer with learning rate `η` and momentum `ρ`.
## Parameters
- Learning Rate (`η`): Amount by which gradients are discounted before updating the weights. Defaults to `0.01`.
- Momentum (`ρ`): Parameter that accelerates descent in the relevant direction and dampens oscillations. Defaults to `0.9`.
# Parameters
- Learning rate (`η`): Amount by which gradients are discounted before updating
the weights.
- Momentum (`ρ`): Controls the acceleration of gradient descent in the
prominent direction, in effect dampening oscillations.
## Examples
# Examples
```julia
opt = Momentum() # uses defaults of η = 0.01 and ρ = 0.9
opt = Momentum()
opt = Momentum(0.01, 0.99)
```
@ -71,17 +74,19 @@ function apply!(o::Momentum, x, Δ)
end
"""
Nesterov(η, ρ)
Nesterov(η = 0.001, ρ = 0.9)
Gradient descent with learning rate `η` and Nesterov momentum `ρ`.
Gradient descent optimizer with learning rate `η` and Nesterov momentum `ρ`.
## Parameters
- Learning Rate (η): Amount by which the gradients are dicsounted berfore updating the weights. Defaults to `0.001`.
- Nesterov Momentum (ρ): Parameters controlling the amount of nesterov momentum to be applied. Defaults to `0.9`.
# Parameters
- Learning rate (`η`): Amount by which gradients are discounted before updating
the weights.
- Nesterov momentum (`ρ`): Controls the acceleration of gradient descent in the
prominent direction, in effect dampening oscillations.
## Examples
# Examples
```julia
opt = Nesterov() # uses defaults η = 0.001 and ρ = 0.9
opt = Nesterov()
opt = Nesterov(0.003, 0.95)
```
@ -103,23 +108,25 @@ function apply!(o::Nesterov, x, Δ)
end
"""
RMSProp(η, ρ)
RMSProp(η = 0.001, ρ = 0.9)
Implements the RMSProp algortihm. Often a good choice for recurrent networks. Parameters other than learning rate generally don't need tuning.
Optimizer using the
[RMSProp](https://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf)
algorithm. Often a good choice for recurrent networks. Parameters other than learning rate
generally don't need tuning.
## Parameters
- Learning Rate (η): Defaults to `0.001`.
- Rho (ρ): Defaults to `0.9`.
# Parameters
- Learning rate (`η`): Amount by which gradients are discounted before updating
the weights.
- Momentum (`ρ`): Controls the acceleration of gradient descent in the
prominent direction, in effect dampening oscillations.
## Examples
# Examples
```julia
opt = RMSProp() # uses default η = 0.001 and ρ = 0.9
opt = RMSProp()
opt = RMSProp(0.002, 0.95)
```
## References
[RMSProp](https://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf)
"""
mutable struct RMSProp
eta::Float64
@ -137,23 +144,22 @@ function apply!(o::RMSProp, x, Δ)
end
"""
ADAM(η, β::Tuple)
ADAM(η = 0.001, β::Tuple = (0.9, 0.999))
Implements the ADAM optimiser.
[ADAM](https://arxiv.org/abs/1412.6980v8) optimiser.
## Paramters
- Learning Rate (`η`): Defaults to `0.001`.
- Beta (`β::Tuple`): The first element refers to β1 and the second to β2. Defaults to `(0.9, 0.999)`.
## Examples
# Parameters
- Learning rate (`η`): Amount by which gradients are discounted before updating
the weights.
- Decay of momentums (`β::Tuple`): Exponential decay for the first (β1) and the
second (β2) momentum estimate.
# Examples
```julia
opt = ADAM() # uses the default η = 0.001 and β = (0.9, 0.999)
opt = ADAM()
opt = ADAM(0.001, (0.9, 0.8))
```
## References
[ADAM](https://arxiv.org/abs/1412.6980v8) optimiser.
"""
mutable struct ADAM
eta::Float64
@ -174,24 +180,22 @@ function apply!(o::ADAM, x, Δ)
end
"""
RADAM(η, β::Tuple)
RADAM(η = 0.001, β::Tuple = (0.9, 0.999))
Implements the rectified ADAM optimizer.
[Rectified ADAM](https://arxiv.org/pdf/1908.03265v1.pdf) optimizer.
## Parameters
- Learning Rate (η): Defaults to `0.001`
- Beta (β::Tuple): The first element refers to β1 and the second to β2. Defaults to `(0.9, 0.999)`.
## Examples
# Parameters
- Learning rate (`η`): Amount by which gradients are discounted before updating
the weights.
- Decay of momentums (`β::Tuple`): Exponential decay for the first (β1) and the
second (β2) momentum estimate.
# Examples
```julia
opt = RADAM() # uses the default η = 0.001 and β = (0.9, 0.999)
opt = RADAM()
opt = RADAM(0.001, (0.9, 0.8))
```
## References
[RADAM](https://arxiv.org/pdf/1908.03265v1.pdf) optimiser (Rectified ADAM).
"""
mutable struct RADAM
eta::Float64
@ -219,22 +223,22 @@ function apply!(o::RADAM, x, Δ)
end
"""
AdaMax(η, β::Tuple)
AdaMax(η = 0.001, β::Tuple = (0.9, 0.999))
Variant of ADAM based on ∞-norm.
[AdaMax](https://arxiv.org/abs/1412.6980v9) is a variant of ADAM based on the ∞-norm.
## Parameters
- Learning Rate (η): Defaults to `0.001`
- Beta (β::Tuple): The first element refers to β1 and the second to β2. Defaults to `(0.9, 0.999)`.
# Parameters
- Learning rate (`η`): Amount by which gradients are discounted before updating
the weights.
- Decay of momentums (`β::Tuple`): Exponential decay for the first (β1) and the
second (β2) momentum estimate.
## Examples
# Examples
```julia
opt = AdaMax() # uses default η and β
opt = AdaMax()
opt = AdaMax(0.001, (0.9, 0.995))
```
## References
[AdaMax](https://arxiv.org/abs/1412.6980v9) optimiser.
"""
mutable struct AdaMax
eta::Float64
@ -255,23 +259,22 @@ function apply!(o::AdaMax, x, Δ)
end
"""
ADAGrad(η)
ADAGrad(η = 0.1)
Implements AdaGrad. It has parameter specific learning rates based on how frequently it is updated.
[ADAGrad](http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf) optimizer. It has
parameter specific learning rates based on how frequently it is updated.
Parameters don't need tuning.
## Parameters
- Learning Rate (η): Defaults to `0.1`
# Parameters
- Learning rate (`η`): Amount by which gradients are discounted before updating
the weights.
## Examples
# Examples
```julia
opt = ADAGrad() # uses default η = 0.1
opt = ADAGrad()
opt = ADAGrad(0.001)
```
## References
[ADAGrad](http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf) optimiser.
Parameters don't need tuning.
"""
mutable struct ADAGrad
eta::Float64
@ -288,21 +291,21 @@ function apply!(o::ADAGrad, x, Δ)
end
"""
ADADelta(ρ)
ADADelta(ρ = 0.9)
Version of ADAGrad that adapts learning rate based on a window of past gradient updates. Parameters don't need tuning.
[ADADelta](https://arxiv.org/abs/1212.5701) is a version of ADAGrad adapting its learning
rate based on a window of past gradient updates.
Parameters don't need tuning.
## Parameters
- Rho (ρ): Factor by which gradient is decayed at each time step. Defaults to `0.9`.
# Parameters
- Rho (`ρ`): Factor by which the gradient is decayed at each time step.
## Examples
# Examples
```julia
opt = ADADelta() # uses default ρ = 0.9
opt = ADADelta()
opt = ADADelta(0.89)
```
## References
[ADADelta](https://arxiv.org/abs/1212.5701) optimiser.
"""
mutable struct ADADelta
rho::Float64
@ -321,22 +324,23 @@ function apply!(o::ADADelta, x, Δ)
end
"""
AMSGrad(η, β::Tuple)
AMSGrad(η = 0.001, β::Tuple = (0.9, 0.999))
Implements AMSGrad version of the ADAM optimiser. Parameters don't need tuning.
The [AMSGrad](https://openreview.net/forum?id=ryQu7f-RZ) version of the ADAM
optimiser. Parameters don't need tuning.
## Parameters
- Learning Rate (η): Defaults to `0.001`.
- Beta (β::Tuple): The first element refers to β1 and the second to β2. Defaults to `(0.9, 0.999)`.
# Parameters
- Learning rate (`η`): Amount by which gradients are discounted before updating
the weights.
- Decay of momentums (`β::Tuple`): Exponential decay for the first (β1) and the
second (β2) momentum estimate.
## Examples
# Examples
```julia
opt = AMSGrad() # uses default η and β
opt = AMSGrad()
opt = AMSGrad(0.001, (0.89, 0.995))
```
## References
[AMSGrad](https://openreview.net/forum?id=ryQu7f-RZ) optimiser.
"""
mutable struct AMSGrad
eta::Float64
@ -356,22 +360,23 @@ function apply!(o::AMSGrad, x, Δ)
end
"""
NADAM(η, β::Tuple)
NADAM(η = 0.001, β::Tuple = (0.9, 0.999))
Nesterov variant of ADAM. Parameters don't need tuning.
[NADAM](http://cs229.stanford.edu/proj2015/054_report.pdf) is a Nesterov variant of ADAM.
Parameters don't need tuning.
## Parameters
- Learning Rate (η): Defaults to `0.001`.
- Beta (β::Tuple): The first element refers to β1 and the second to β2. Defaults to `(0.9, 0.999)`.
# Parameters
- Learning rate (`η`): Amount by which gradients are discounted before updating
the weights.
- Decay of momentums (`β::Tuple`): Exponential decay for the first (β1) and the
second (β2) momentum estimate.
## Examples
# Examples
```julia
opt = NADAM() # uses default η and β
opt = NADAM()
opt = NADAM(0.002, (0.89, 0.995))
```
## References
[NADAM](http://cs229.stanford.edu/proj2015/054_report.pdf) optimiser.
"""
mutable struct NADAM
eta::Float64
@ -392,23 +397,24 @@ function apply!(o::NADAM, x, Δ)
end
"""
ADAMW(η, β::Tuple, decay)
ADAMW(η = 0.001, β::Tuple = (0.9, 0.999), decay = 0)
Variant of ADAM defined by fixing weight decay regularization.
[ADAMW](https://arxiv.org/abs/1711.05101) is a variant of ADAM fixing (as in repairing) its
weight decay regularization.
## Parameters
- Learning Rate (η): Defaults to `0.001`.
- Beta (β::Tuple): The first element refers to β1 and the second to β2. Defaults to (0.9, 0.999).
- decay: Decay applied to weights during optimisation. Defaults to 0.
# Parameters
- Learning rate (`η`): Amount by which gradients are discounted before updating
the weights.
- Decay of momentums (`β::Tuple`): Exponential decay for the first (β1) and the
second (β2) momentum estimate.
- `decay`: Decay applied to weights during optimisation.
## Examples
# Examples
```julia
opt = ADAMW() # uses default η, β and decay
opt = ADAMW()
opt = ADAMW(0.001, (0.89, 0.995), 0.1)
```
## References
[ADAMW](https://arxiv.org/abs/1711.05101)
"""
ADAMW(η = 0.001, β = (0.9, 0.999), decay = 0) =
Optimiser(ADAM(η, β), WeightDecay(decay))
@ -441,14 +447,13 @@ function apply!(o::Optimiser, x, Δ)
end
"""
InvDecay(γ)
InvDecay(γ = 0.001)
Applies inverse time decay to an optimiser, i.e., the effective step size at iteration `n` is `eta / (1 + γ * n)` where `eta` is the initial step size. The wrapped optimiser's step size is not modified.
Apply inverse time decay to an optimiser, so that the effective step size at
iteration `n` is `eta / (1 + γ * n)` where `eta` is the initial step size.
The wrapped optimiser's step size is not modified.
## Parameters
- gamma (γ): Defaults to `0.001`
## Example
# Examples
```julia
Optimiser(InvDecay(..), Opt(..))
```
@ -469,20 +474,24 @@ function apply!(o::InvDecay, x, Δ)
end
"""
ExpDecay(eta, decay, decay_step, clip)
ExpDecay(η = 0.001, decay = 0.1, decay_step = 1000, clip = 1e-4)
Discount the learning rate `eta` by a multiplicative factor `decay` every `decay_step` till a minimum of `clip`.
Discount the learning rate `η` by the factor `decay` every `decay_step` steps till
a minimum of `clip`.
## Parameters
- Learning Rate (eta): Defaults to `0.001`.
- decay: Factor by which the learning rate is discounted. Defaults to `0.1`.
- decay_step: Schedules decay operations by setting number of steps between two decay operations. Defaults to `1000`.
- clip: Minimum value of learning rate. Defaults to `1e-4`.
# Parameters
- Learning rate (`η`): Amount by which gradients are discounted before updating
the weights.
- `decay`: Factor by which the learning rate is discounted.
- `decay_step`: Schedule decay operations by setting the number of steps between
two decay operations.
- `clip`: Minimum value of learning rate.
## Example
# Examples
To apply exponential decay to an optimiser:
```julia
Optimiser(ExpDecay(..), Opt(..))
opt = Optimiser(ExpDecay(), ADAM())
```
"""
@ -507,12 +516,12 @@ function apply!(o::ExpDecay, x, Δ)
end
"""
WeightDecay(wd)
WeightDecay(wd = 0)
Decays the weight by `wd`
Decay weights by `wd`.
## Parameters
- weight decay (wd): 0
# Parameters
- Weight decay (`wd`)
"""
mutable struct WeightDecay
wd::Real

View File

@ -2,6 +2,16 @@ using Juno
import Zygote: Params, gradient
"""
update!(x, x̄)
Update the array `x` according to `x .-= x̄`.
"""
function update!(x::AbstractArray, x̄)
x .-= x̄
end
"""
update!(opt, p, g)
update!(opt, ps::Params, gs)
@ -10,15 +20,7 @@ Perform an update step of the parameters `ps` (or the single parameter `p`)
according to optimizer `opt` and the gradients `gs` (the gradient `g`).
As a result, the parameters are mutated and the optimizer's internal state may change.
update!(x, x̄)
Update the array `x` according to `x .-= x̄`.
"""
function update!(x::AbstractArray, x̄)
  x .-= x̄
end
function update!(opt, x, x̄)
  x .-= apply!(opt, x, x̄)
end
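A minimal sketch of the single-parameter method shown above; the array and the "gradient" are random placeholders, and `Descent` is just one possible optimiser:
```julia
using Flux
using Flux.Optimise: update!

W = rand(2, 3)
g = 0.5 .* W                  # stand-in for a real gradient
update!(Descent(0.1), W, g)   # in place: W .-= 0.1 .* g
```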
@ -41,11 +43,10 @@ struct StopException <: Exception end
stop()
Call `Flux.stop()` in a callback to indicate when a callback condition is met.
This would trigger the train loop to stop and exit.
This will trigger the train loop to stop and exit.
# Examples
```julia
# Example callback:
cb = function ()
accuracy() > 0.9 && Flux.stop()
end
@ -58,19 +59,19 @@ end
"""
train!(loss, params, data, opt; cb)
For each datapoint `d` in `data` computes the gradient of `loss(d...)` through
backpropagation and calls the optimizer `opt`.
For each datapoint `d` in `data` compute the gradient of `loss(d...)` through
backpropagation and call the optimizer `opt`.
In case datapoints `d` are of numeric array type, assumes no splatting is needed
and computes the gradient of `loss(d)`.
In case datapoints `d` are of numeric array type, assume no splatting is needed
and compute the gradient of `loss(d)`.
Takes a callback as keyword argument `cb`. For example, this will print "training"
every 10 seconds:
A callback is given with the keyword argument `cb`. For example, this will print
"training" every 10 seconds (using [`Flux.throttle`](@ref)):
train!(loss, params, data, opt,
cb = throttle(() -> println("training"), 10))
The callback can call `Flux.stop()` to interrupt the training loop.
The callback can call [`Flux.stop`](@ref) to interrupt the training loop.
Multiple optimisers and callbacks can be passed to `opt` and `cb` as arrays.
"""
@ -106,11 +107,12 @@ end
Run `body` `N` times. Mainly useful for quickly doing multiple epochs of
training in a REPL.
```julia
julia> @epochs 2 println("hello")
INFO: Epoch 1
# Examples
```jldoctest
julia> Flux.@epochs 2 println("hello")
[ Info: Epoch 1
hello
INFO: Epoch 2
[ Info: Epoch 2
hello
```
"""


@ -1,10 +1,40 @@
# Arrays
nfan() = 1, 1 #fan_in, fan_out
nfan(n) = 1, n #A vector is treated as a n×1 matrix
nfan(n_out, n_in) = n_in, n_out #In case of Dense kernels: arranged as matrices
nfan(dims...) = prod(dims[1:end-2]) .* (dims[end-1], dims[end]) #In case of convolution kernels
nfan() = 1, 1 # fan_in, fan_out
nfan(n) = 1, n # A vector is treated as a n×1 matrix
nfan(n_out, n_in) = n_in, n_out # In case of Dense kernels: arranged as matrices
nfan(dims...) = prod(dims[1:end-2]) .* (dims[end-1], dims[end]) # In case of convolution kernels
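To make the fan-in/fan-out convention above concrete, here is what these helpers return for a few representative shapes; the values follow directly from the definitions, and the shapes themselves are arbitrary examples:
```julia
Flux.nfan()            # (1, 1)    scalar bias
Flux.nfan(3)           # (1, 3)    length-3 vector
Flux.nfan(2, 3)        # (3, 2)    Dense weights: fan_in = 3, fan_out = 2
Flux.nfan(3, 3, 1, 8)  # (9, 72)   3×3 conv kernel, 1 input and 8 output channels
```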
"""
glorot_uniform(dims...)
Return an `Array` of size `dims` containing random variables taken from a uniform
distribution in the interval ``[-x, x]``, where `x = sqrt(24 / sum(dims)) / 2`.
# Examples
```jldoctest; setup = :(using Random; Random.seed!(0))
julia> Flux.glorot_uniform(2, 3)
2×3 Array{Float32,2}:
0.601094 -0.57414 -0.814925
0.900868 0.805994 0.057514
```
"""
glorot_uniform(dims...) = (rand(Float32, dims...) .- 0.5f0) .* sqrt(24.0f0 / sum(nfan(dims...)))
"""
glorot_normal(dims...)
Return an `Array` of size `dims` containing random variables taken from a normal
distribution with mean 0 and standard deviation `sqrt(2 / sum(dims))`.
# Examples
```jldoctest; setup = :(using Random; Random.seed!(0))
julia> Flux.glorot_normal(3, 2)
3×2 Array{Float32,2}:
0.429505 -0.0852891
0.523935 0.371009
-0.223261 0.188052
```
"""
glorot_normal(dims...) = randn(Float32, dims...) .* sqrt(2.0f0 / sum(nfan(dims...)))
ones(T::Type, dims...) = Base.ones(T, dims...)
@ -13,9 +43,81 @@ zeros(T::Type, dims...) = Base.zeros(T, dims...)
ones(dims...) = Base.ones(Float32, dims...)
zeros(dims...) = Base.zeros(Float32, dims...)
"""
unsqueeze(xs, dim)
Return `xs` reshaped into an array with one more dimension than `xs`, where the
new dimension, of size 1, is inserted at position `dim`.
# Examples
```jldoctest
julia> xs = [[1, 2], [3, 4], [5, 6]]
3-element Array{Array{Int64,1},1}:
[1, 2]
[3, 4]
[5, 6]
julia> Flux.unsqueeze(xs, 1)
1×3 Array{Array{Int64,1},2}:
[1, 2] [3, 4] [5, 6]
julia> Flux.unsqueeze([1 2; 3 4], 2)
2×1×2 Array{Int64,3}:
[:, :, 1] =
1
3
[:, :, 2] =
2
4
```
"""
unsqueeze(xs, dim) = reshape(xs, (size(xs)[1:dim-1]..., 1, size(xs)[dim:end]...))
"""
stack(xs, dim)
Stack the `Array`s in `xs` into a single `Array` along dimension `dim` by
inserting a new dimension there in each element (see [`unsqueeze`](@ref)) and
concatenating; the result has one more dimension than the elements of `xs`.
# Examples
```jldoctest
julia> xs = [[1, 2], [3, 4], [5, 6]]
3-element Array{Array{Int64,1},1}:
[1, 2]
[3, 4]
[5, 6]
julia> Flux.stack(xs, 1)
3×2 Array{Int64,2}:
1 2
3 4
5 6
julia> cat(xs, dims=1)
3-element Array{Array{Int64,1},1}:
[1, 2]
[3, 4]
[5, 6]
```
"""
stack(xs, dim) = cat(unsqueeze.(xs, dim)..., dims=dim)
"""
unstack(xs, dim)
Slice `xs` along dimension `dim`, returning an `Array` of the lower-dimensional
slices; this is the inverse of [`stack`](@ref).
# Examples
```jldoctest
julia> Flux.unstack([1 3 5 7; 2 4 6 8], 2)
4-element Array{Array{Int64,1},1}:
[1, 2]
[3, 4]
[5, 6]
[7, 8]
```
"""
unstack(xs, dim) = [copy(selectdim(xs, dim, i)) for i in 1:size(xs, dim)]
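Given the two definitions above, `stack` and `unstack` should round-trip along the same dimension; a small illustrative check (the input is arbitrary):
```julia
xs = [[1, 2], [3, 4], [5, 6]]

Flux.unstack(Flux.stack(xs, 1), 1) == xs   # true
```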
"""
@ -23,9 +125,16 @@ unstack(xs, dim) = [copy(selectdim(xs, dim, i)) for i in 1:size(xs, dim)]
Split `xs` into `n` parts.
```julia
julia> chunk(1:10, 3)
3-element Array{Array{Int64,1},1}:
# Examples
```jldoctest
julia> Flux.chunk(1:10, 3)
3-element Array{UnitRange{Int64},1}:
1:4
5:8
9:10
julia> Flux.chunk(collect(1:10), 3)
3-element Array{SubArray{Int64,1,Array{Int64,1},Tuple{UnitRange{Int64}},true},1}:
[1, 2, 3, 4]
[5, 6, 7, 8]
[9, 10]
@ -40,11 +149,12 @@ batchindex(xs, i) = (reverse(Base.tail(reverse(axes(xs))))..., i)
Count the number of times that each element of `xs` appears.
```julia
julia> frequencies(['a','b','b'])
# Examples
```jldoctest
julia> Flux.frequencies(['a','b','b'])
Dict{Char,Int64} with 2 entries:
'b' => 2
'a' => 1
'b' => 2
```
"""
function frequencies(xs)
@ -64,8 +174,9 @@ squeezebatch(x) = reshape(x, head(size(x)))
Batch the arrays in `xs` into a single array.
```julia
julia> batch([[1,2,3],[4,5,6]])
# Examples
```jldoctest
julia> Flux.batch([[1,2,3],[4,5,6]])
3×2 Array{Int64,2}:
1 4
2 5
@ -82,6 +193,25 @@ function batch(xs)
return data
end
"""
Return the given sequence padded on the right with `p` so that its length is at
least `n`; sequences that are already `n` elements or longer are returned unchanged.
# Examples
```jldoctest
julia> rpad([1, 2], 4, 0)
4-element Array{Int64,1}:
1
2
0
0
julia> rpad([1, 2, 3], 2, 0)
3-element Array{Int64,1}:
1
2
3
```
"""
Base.rpad(v::AbstractVector, n::Integer, p) = [v; fill(p, max(n - length(v), 0))]
"""
@ -90,8 +220,9 @@ Base.rpad(v::AbstractVector, n::Integer, p) = [v; fill(p, max(n - length(v), 0))
Take a list of `N` sequences, and turn them into a single sequence where each
item is a batch of `N`. Short sequences will be padded by `pad`.
```julia
julia> batchseq([[1, 2, 3], [4, 5]], 0)
# Examples
```jldoctest
julia> Flux.batchseq([[1, 2, 3], [4, 5]], 0)
3-element Array{Array{Int64,1},1}:
[1, 4]
[2, 5]
@ -148,11 +279,15 @@ end
# Other
"""
Returns a function that when invoked, will only be triggered at most once
during `timeout` seconds. Normally, the throttled function will run
as much as it can, without ever going more than once per `wait` duration;
but if you'd like to disable the execution on the leading edge, pass
`leading=false`. To enable execution on the trailing edge, ditto.
throttle(f, timeout; leading=true, trailing=false)
Return a function that, when invoked, calls `f` at most once during any
`timeout`-second window.
Normally, the throttled function will run as much as it can, without ever
going more than once per `timeout` duration; but if you'd like to disable the
execution on the leading edge, pass `leading=false`. To enable execution on
the trailing edge, pass `trailing=true`.
"""
function throttle(f, timeout; leading=true, trailing=false)
cooldown = true
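Because the docstring above has no example, here is a hedged sketch of typical usage; the 10-second window and the print callback are arbitrary:
```julia
using Flux: throttle

cb = throttle(() -> println("still training"), 10)

# However often this loop calls `cb`, the message prints at most
# once every 10 seconds.
for _ in 1:10^6
    cb()
end
```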


@ -41,7 +41,8 @@ Random.seed!(0)
end
@testset "Docs" begin
if VERSION >= v"1.2"
if VERSION >= v"1.4"
DocMeta.setdocmeta!(Flux, :DocTestSetup, :(using Flux); recursive=true)
doctest(Flux)
end
end