Merge #853
853: Improve docs r=CarloLucibello a=janEbert If you disagree with any of the changes, please tell me what to reverse or fix. I am unsure about the docstrings I added to `src/utils.jl` for `unsqueeze` and the `[un]stack` functions so please give those a more detailed look. Update Documenter.jl version for new features, fix deprecation warnings in `docs/make.jl` and import Flux for all doctests. Add missing docstrings to `src/utils.jl`, `src/layers/stateless.jl` and `src/data/`; add these and other missing functions to Markdown docs. Improve docstrings by... - fixing typos, - removing trailing or double whitespaces, - using `jldoctest` blocks where applicable, - fixing, updating or correctly setting up existing doctests, - improving consistency (for example, always use "# Examples" instead of other variants), - removing empty lines between docstrings and functions, - instead of mentioning keywords, put them into the docstring, - adding some missing but useful keywords, - adding references (`@ref`), - using LaTeX math where applicable, and - linking papers. Debatable stuff that is untouched: - BE/AE s/z irregularities (e.g. "normalise" versus "normalize") since most papers use the AE version while the Flux source code was written with BE spelling. - Names of normalization functions are capitalized ("Batch Normalization" instead of "batch normalization"). - Default values in argument lists have spaces around the equals sign (`arg = x` instead of `arg=x`). Co-authored-by: janEbert <janpublicebert@posteo.net>
This commit is contained in:
commit
7a32a703f0
14
docs/make.jl
14
docs/make.jl
|
@ -1,6 +1,8 @@
|
||||||
using Documenter, Flux, NNlib
|
using Documenter, Flux, NNlib
|
||||||
|
|
||||||
|
DocMeta.setdocmeta!(Flux, :DocTestSetup, :(using Flux); recursive=true)
|
||||||
makedocs(modules=[Flux, NNlib],
|
makedocs(modules=[Flux, NNlib],
|
||||||
|
doctest = VERSION >= v"1.4",
|
||||||
sitename = "Flux",
|
sitename = "Flux",
|
||||||
pages = ["Home" => "index.md",
|
pages = ["Home" => "index.md",
|
||||||
"Building Models" =>
|
"Building Models" =>
|
||||||
|
@ -19,12 +21,16 @@ makedocs(modules=[Flux, NNlib],
|
||||||
"GPU Support" => "gpu.md",
|
"GPU Support" => "gpu.md",
|
||||||
"Saving & Loading" => "saving.md",
|
"Saving & Loading" => "saving.md",
|
||||||
"The Julia Ecosystem" => "ecosystem.md",
|
"The Julia Ecosystem" => "ecosystem.md",
|
||||||
|
"Utility Functions" => "utilities.md",
|
||||||
"Performance Tips" => "performance.md",
|
"Performance Tips" => "performance.md",
|
||||||
|
"Datasets" => "datasets.md",
|
||||||
"Community" => "community.md"],
|
"Community" => "community.md"],
|
||||||
format = Documenter.HTML(assets = ["assets/flux.css"],
|
format = Documenter.HTML(
|
||||||
analytics = "UA-36890222-9",
|
analytics = "UA-36890222-9",
|
||||||
prettyurls = haskey(ENV, "CI")))
|
assets = ["assets/flux.css"],
|
||||||
|
prettyurls = get(ENV, "CI", nothing) == "true"),
|
||||||
|
)
|
||||||
|
|
||||||
deploydocs(repo = "github.com/FluxML/Flux.jl.git",
|
deploydocs(repo = "github.com/FluxML/Flux.jl.git",
|
||||||
target = "build",
|
target = "build",
|
||||||
push_preview = true)
|
push_preview = true)
|
||||||
|
|
|
@ -3,4 +3,4 @@ Flux provides the `DataLoader` type in the `Flux.Data` module to handle iteratio
|
||||||
|
|
||||||
```@docs
|
```@docs
|
||||||
Flux.Data.DataLoader
|
Flux.Data.DataLoader
|
||||||
```
|
```
|
||||||
|
|
|
@ -31,6 +31,11 @@ julia> onecold([0.3, 0.2, 0.5], [:a, :b, :c])
|
||||||
:c
|
:c
|
||||||
```
|
```
|
||||||
|
|
||||||
|
```@docs
|
||||||
|
Flux.onehot
|
||||||
|
Flux.onecold
|
||||||
|
```
|
||||||
|
|
||||||
## Batches
|
## Batches
|
||||||
|
|
||||||
`onehotbatch` creates a batch (matrix) of one-hot vectors, and `onecold` treats matrices as batches.
|
`onehotbatch` creates a batch (matrix) of one-hot vectors, and `onecold` treats matrices as batches.
|
||||||
|
@ -52,3 +57,7 @@ julia> onecold(ans, [:a, :b, :c])
|
||||||
```
|
```
|
||||||
|
|
||||||
Note that these operations returned `OneHotVector` and `OneHotMatrix` rather than `Array`s. `OneHotVector`s behave like normal vectors but avoid any unnecessary cost compared to using an integer index directly. For example, multiplying a matrix with a one-hot vector simply slices out the relevant row of the matrix under the hood.
|
Note that these operations returned `OneHotVector` and `OneHotMatrix` rather than `Array`s. `OneHotVector`s behave like normal vectors but avoid any unnecessary cost compared to using an integer index directly. For example, multiplying a matrix with a one-hot vector simply slices out the relevant row of the matrix under the hood.
|
||||||
|
|
||||||
|
```@docs
|
||||||
|
Flux.onehotbatch
|
||||||
|
```
|
||||||
|
|
|
@ -0,0 +1,20 @@
|
||||||
|
# Datasets
|
||||||
|
|
||||||
|
Flux includes several standard machine learning datasets.
|
||||||
|
|
||||||
|
```@docs
|
||||||
|
Flux.Data.Iris.features()
|
||||||
|
Flux.Data.Iris.labels()
|
||||||
|
Flux.Data.MNIST.images()
|
||||||
|
Flux.Data.MNIST.labels()
|
||||||
|
Flux.Data.FashionMNIST.images()
|
||||||
|
Flux.Data.FashionMNIST.labels()
|
||||||
|
Flux.Data.CMUDict.phones()
|
||||||
|
Flux.Data.CMUDict.symbols()
|
||||||
|
Flux.Data.CMUDict.rawdict()
|
||||||
|
Flux.Data.CMUDict.cmudict()
|
||||||
|
Flux.Data.Sentiment.train()
|
||||||
|
Flux.Data.Sentiment.test()
|
||||||
|
Flux.Data.Sentiment.dev()
|
||||||
|
```
|
||||||
|
|
|
@ -220,7 +220,7 @@ Flux.@functor Affine
|
||||||
|
|
||||||
This enables a useful extra set of functionality for our `Affine` layer, such as [collecting its parameters](../training/optimisers.md) or [moving it to the GPU](../gpu.md).
|
This enables a useful extra set of functionality for our `Affine` layer, such as [collecting its parameters](../training/optimisers.md) or [moving it to the GPU](../gpu.md).
|
||||||
|
|
||||||
For some more helpful tricks, including parameter freezing, please checkout the [advanced usage guide](advacned.md).
|
For some more helpful tricks, including parameter freezing, please checkout the [advanced usage guide](advanced.md).
|
||||||
|
|
||||||
## Utility functions
|
## Utility functions
|
||||||
|
|
||||||
|
@ -240,5 +240,5 @@ Currently limited to the following layers:
|
||||||
- `MeanPool`
|
- `MeanPool`
|
||||||
|
|
||||||
```@docs
|
```@docs
|
||||||
outdims
|
Flux.outdims
|
||||||
```
|
```
|
||||||
|
|
|
@ -32,6 +32,7 @@ RNN
|
||||||
LSTM
|
LSTM
|
||||||
GRU
|
GRU
|
||||||
Flux.Recur
|
Flux.Recur
|
||||||
|
Flux.reset!
|
||||||
```
|
```
|
||||||
|
|
||||||
## Other General Purpose Layers
|
## Other General Purpose Layers
|
||||||
|
@ -49,20 +50,22 @@ SkipConnection
|
||||||
These layers don't affect the structure of the network but may improve training times or reduce overfitting.
|
These layers don't affect the structure of the network but may improve training times or reduce overfitting.
|
||||||
|
|
||||||
```@docs
|
```@docs
|
||||||
|
Flux.normalise
|
||||||
BatchNorm
|
BatchNorm
|
||||||
Dropout
|
|
||||||
Flux.dropout
|
Flux.dropout
|
||||||
|
Dropout
|
||||||
AlphaDropout
|
AlphaDropout
|
||||||
LayerNorm
|
LayerNorm
|
||||||
|
InstanceNorm
|
||||||
GroupNorm
|
GroupNorm
|
||||||
```
|
```
|
||||||
|
|
||||||
### Testmode
|
### Testmode
|
||||||
|
|
||||||
Many normalisation layers behave differently under training and inference (testing). By default, Flux will automatically determine when a layer evaluation is part of training or inference. Still, depending on your use case, it may be helpful to manually specify when these layers should be treated as being trained or not. For this, Flux provides `testmode!`. When called on a model (e.g. a layer or chain of layers), this function will place the model into the mode specified.
|
Many normalisation layers behave differently under training and inference (testing). By default, Flux will automatically determine when a layer evaluation is part of training or inference. Still, depending on your use case, it may be helpful to manually specify when these layers should be treated as being trained or not. For this, Flux provides `Flux.testmode!`. When called on a model (e.g. a layer or chain of layers), this function will place the model into the mode specified.
|
||||||
|
|
||||||
```@docs
|
```@docs
|
||||||
testmode!
|
Flux.testmode!
|
||||||
trainmode!
|
trainmode!
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|
|
@ -64,3 +64,7 @@ julia> activations(c, rand(10))
|
||||||
julia> sum(norm, ans)
|
julia> sum(norm, ans)
|
||||||
2.1166067f0
|
2.1166067f0
|
||||||
```
|
```
|
||||||
|
|
||||||
|
```@docs
|
||||||
|
Flux.activations
|
||||||
|
```
|
||||||
|
|
|
@ -52,7 +52,7 @@ e.g.
|
||||||
```julia
|
```julia
|
||||||
function loss_total(xs::AbstractVector{<:Vector}, ys::AbstractVector{<:Vector})
|
function loss_total(xs::AbstractVector{<:Vector}, ys::AbstractVector{<:Vector})
|
||||||
sum(zip(xs, ys)) do (x, y_target)
|
sum(zip(xs, ys)) do (x, y_target)
|
||||||
y_pred = model(x) # evaluate the model
|
y_pred = model(x) # evaluate the model
|
||||||
return loss(y_pred, y_target)
|
return loss(y_pred, y_target)
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
|
@ -52,6 +52,7 @@ Momentum
|
||||||
Nesterov
|
Nesterov
|
||||||
RMSProp
|
RMSProp
|
||||||
ADAM
|
ADAM
|
||||||
|
RADAM
|
||||||
AdaMax
|
AdaMax
|
||||||
ADAGrad
|
ADAGrad
|
||||||
ADADelta
|
ADADelta
|
||||||
|
|
|
@ -32,6 +32,7 @@ Flux.train!(loss, ps, data, opt)
|
||||||
```
|
```
|
||||||
|
|
||||||
The objective will almost always be defined in terms of some *cost function* that measures the distance of the prediction `m(x)` from the target `y`. Flux has several of these built in, like `mse` for mean squared error or `crossentropy` for cross entropy loss, but you can calculate it however you want.
|
The objective will almost always be defined in terms of some *cost function* that measures the distance of the prediction `m(x)` from the target `y`. Flux has several of these built in, like `mse` for mean squared error or `crossentropy` for cross entropy loss, but you can calculate it however you want.
|
||||||
|
For a list of all built-in loss functions, check out the [layer reference](../models/layers.md).
|
||||||
|
|
||||||
At first glance it may seem strange that the model that we want to train is not part of the input arguments of `Flux.train!` too. However the target of the optimizer is not the model itself, but the objective function that represents the departure between modelled and observed data. In other words, the model is implicitly defined in the objective function, and there is no need to give it explicitly. Passing the objective function instead of the model and a cost function separately provides more flexibility, and the possibility of optimizing the calculations.
|
At first glance it may seem strange that the model that we want to train is not part of the input arguments of `Flux.train!` too. However the target of the optimizer is not the model itself, but the objective function that represents the departure between modelled and observed data. In other words, the model is implicitly defined in the objective function, and there is no need to give it explicitly. Passing the objective function instead of the model and a cost function separately provides more flexibility, and the possibility of optimizing the calculations.
|
||||||
|
|
||||||
|
@ -94,6 +95,10 @@ julia> @epochs 2 Flux.train!(...)
|
||||||
# Train for two epochs
|
# Train for two epochs
|
||||||
```
|
```
|
||||||
|
|
||||||
|
```@docs
|
||||||
|
Flux.@epochs
|
||||||
|
```
|
||||||
|
|
||||||
## Callbacks
|
## Callbacks
|
||||||
|
|
||||||
`train!` takes an additional argument, `cb`, that's used for callbacks so that you can observe the training process. For example:
|
`train!` takes an additional argument, `cb`, that's used for callbacks so that you can observe the training process. For example:
|
||||||
|
|
|
@ -0,0 +1,49 @@
|
||||||
|
# Utility Functions
|
||||||
|
|
||||||
|
Flux contains some utility functions for working with data; these functions
|
||||||
|
help create inputs for your models or batch your dataset.
|
||||||
|
Other functions can be used to initialize your layers or to regularly execute
|
||||||
|
callback functions.
|
||||||
|
|
||||||
|
## Working with Data
|
||||||
|
|
||||||
|
```@docs
|
||||||
|
Flux.unsqueeze
|
||||||
|
Flux.stack
|
||||||
|
Flux.unstack
|
||||||
|
Flux.chunk
|
||||||
|
Flux.frequencies
|
||||||
|
Flux.batch
|
||||||
|
Flux.batchseq
|
||||||
|
Base.rpad(v::AbstractVector, n::Integer, p)
|
||||||
|
```
|
||||||
|
|
||||||
|
## Layer Initialization
|
||||||
|
|
||||||
|
These are primarily useful if you are planning to write your own layers.
|
||||||
|
Flux initializes convolutional layers and recurrent cells with `glorot_uniform`
|
||||||
|
by default.
|
||||||
|
To change the default on an applicable layer, pass the desired function with the
|
||||||
|
`init` keyword. For example:
|
||||||
|
```jldoctest; setup = :(using Flux)
|
||||||
|
julia> conv = Conv((3, 3), 1 => 8, relu; init=Flux.glorot_normal)
|
||||||
|
Conv((3, 3), 1=>8, relu)
|
||||||
|
```
|
||||||
|
|
||||||
|
```@docs
|
||||||
|
Flux.glorot_uniform
|
||||||
|
Flux.glorot_normal
|
||||||
|
```
|
||||||
|
|
||||||
|
## Model Abstraction
|
||||||
|
|
||||||
|
```@docs
|
||||||
|
Flux.destructure
|
||||||
|
```
|
||||||
|
|
||||||
|
## Callback Helpers
|
||||||
|
|
||||||
|
```@docs
|
||||||
|
Flux.throttle
|
||||||
|
Flux.stop
|
||||||
|
```
|
|
@ -24,18 +24,35 @@ function load()
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
|
"""
|
||||||
|
phones()
|
||||||
|
|
||||||
|
Return a `Vector` containing the phones used in the CMU Pronouncing Dictionary.
|
||||||
|
"""
|
||||||
function phones()
|
function phones()
|
||||||
load()
|
load()
|
||||||
Symbol.(first.(split.(split(read(deps("cmudict", "cmudict.phones"),String),
|
Symbol.(first.(split.(split(read(deps("cmudict", "cmudict.phones"),String),
|
||||||
"\n", keepempty = false), "\t")))
|
"\n", keepempty = false), "\t")))
|
||||||
end
|
end
|
||||||
|
|
||||||
|
"""
|
||||||
|
symbols()
|
||||||
|
|
||||||
|
Return a `Vector` containing the symbols used in the CMU Pronouncing Dictionary.
|
||||||
|
A symbol is a phone with optional auxiliary symbols, indicating for example the
|
||||||
|
amount of stress on the phone.
|
||||||
|
"""
|
||||||
function symbols()
|
function symbols()
|
||||||
load()
|
load()
|
||||||
Symbol.(split(read(deps("cmudict", "cmudict.symbols"),String),
|
Symbol.(split(read(deps("cmudict", "cmudict.symbols"),String),
|
||||||
"\n", keepempty = false))
|
"\n", keepempty = false))
|
||||||
end
|
end
|
||||||
|
|
||||||
|
"""
|
||||||
|
rawdict()
|
||||||
|
|
||||||
|
Return the unfiltered CMU Pronouncing Dictionary.
|
||||||
|
"""
|
||||||
function rawdict()
|
function rawdict()
|
||||||
load()
|
load()
|
||||||
Dict(String(xs[1]) => Symbol.(xs[2:end]) for xs in
|
Dict(String(xs[1]) => Symbol.(xs[2:end]) for xs in
|
||||||
|
@ -44,6 +61,14 @@ end
|
||||||
|
|
||||||
validword(s) = isascii(s) && occursin(r"^[\w\-\.]+$", s)
|
validword(s) = isascii(s) && occursin(r"^[\w\-\.]+$", s)
|
||||||
|
|
||||||
|
"""
|
||||||
|
cmudict()
|
||||||
|
|
||||||
|
Return a filtered CMU Pronouncing Dictionary.
|
||||||
|
|
||||||
|
It is filtered so each word contains only ASCII characters and a combination of
|
||||||
|
word characters (as determined by the regex engine using `\\w`), '-' and '.'.
|
||||||
|
"""
|
||||||
cmudict() = filter(p -> validword(p.first), rawdict())
|
cmudict() = filter(p -> validword(p.first), rawdict())
|
||||||
|
|
||||||
alphabet() = ['A':'Z'..., '0':'9'..., '_', '-', '.']
|
alphabet() = ['A':'Z'..., '0':'9'..., '_', '-', '.']
|
||||||
|
|
|
@ -33,9 +33,10 @@ const TESTLABELS = joinpath(dir, "t10k-labels-idx1-ubyte")
|
||||||
|
|
||||||
Load the Fashion-MNIST images.
|
Load the Fashion-MNIST images.
|
||||||
|
|
||||||
Each image is a 28×28 array of `Gray` colour values (see Colors.jl).
|
Each image is a 28×28 array of `Gray` colour values
|
||||||
|
(see [Colors.jl](https://github.com/JuliaGraphics/Colors.jl)).
|
||||||
|
|
||||||
Returns the 60,000 training images by default; pass `:test` to retreive the
|
Return the 60,000 training images by default; pass `:test` to retrieve the
|
||||||
10,000 test images.
|
10,000 test images.
|
||||||
"""
|
"""
|
||||||
function images(set = :train)
|
function images(set = :train)
|
||||||
|
@ -49,10 +50,10 @@ end
|
||||||
labels()
|
labels()
|
||||||
labels(:test)
|
labels(:test)
|
||||||
|
|
||||||
Load the labels corresponding to each of the images returned from `images()`.
|
Load the labels corresponding to each of the images returned from [`images()`](@ref).
|
||||||
Each label is a number from 0-9.
|
Each label is a number from 0-9.
|
||||||
|
|
||||||
Returns the 60,000 training labels by default; pass `:test` to retreive the
|
Return the 60,000 training labels by default; pass `:test` to retrieve the
|
||||||
10,000 test labels.
|
10,000 test labels.
|
||||||
"""
|
"""
|
||||||
function labels(set = :train)
|
function labels(set = :train)
|
||||||
|
|
|
@ -2,13 +2,12 @@
|
||||||
Fisher's classic iris dataset.
|
Fisher's classic iris dataset.
|
||||||
|
|
||||||
Measurements from 3 different species of iris: setosa, versicolor and
|
Measurements from 3 different species of iris: setosa, versicolor and
|
||||||
virginica. There are 50 examples of each species.
|
virginica. There are 50 examples of each species.
|
||||||
|
|
||||||
There are 4 measurements for each example: sepal length, sepal width, petal
|
There are 4 measurements for each example: sepal length, sepal width,
|
||||||
length and petal width. The measurements are in centimeters.
|
petal length and petal width. The measurements are in centimeters.
|
||||||
|
|
||||||
The module retrieves the data from the [UCI Machine Learning Repository](https://archive.ics.uci.edu/ml/datasets/iris).
|
The module retrieves the data from the [UCI Machine Learning Repository](https://archive.ics.uci.edu/ml/datasets/iris).
|
||||||
|
|
||||||
"""
|
"""
|
||||||
module Iris
|
module Iris
|
||||||
|
|
||||||
|
@ -33,9 +32,7 @@ end
|
||||||
Get the labels of the iris dataset, a 150 element array of strings listing the
|
Get the labels of the iris dataset, a 150 element array of strings listing the
|
||||||
species of each example.
|
species of each example.
|
||||||
|
|
||||||
```jldoctest
|
```jldoctest; setup = :(Flux.Data.Iris.load())
|
||||||
julia> using Flux
|
|
||||||
|
|
||||||
julia> labels = Flux.Data.Iris.labels();
|
julia> labels = Flux.Data.Iris.labels();
|
||||||
|
|
||||||
julia> summary(labels)
|
julia> summary(labels)
|
||||||
|
@ -54,13 +51,11 @@ end
|
||||||
"""
|
"""
|
||||||
features()
|
features()
|
||||||
|
|
||||||
Get the features of the iris dataset. This is a 4x150 matrix of Float64
|
Get the features of the iris dataset. This is a 4x150 matrix of Float64
|
||||||
elements. It has a row for each feature (sepal length, sepal width,
|
elements. It has a row for each feature (sepal length, sepal width,
|
||||||
petal length, petal width) and a column for each example.
|
petal length, petal width) and a column for each example.
|
||||||
|
|
||||||
```jldoctest
|
```jldoctest; setup = :(Flux.Data.Iris.load())
|
||||||
julia> using Flux
|
|
||||||
|
|
||||||
julia> features = Flux.Data.Iris.features();
|
julia> features = Flux.Data.Iris.features();
|
||||||
|
|
||||||
julia> summary(features)
|
julia> summary(features)
|
||||||
|
|
|
@ -83,9 +83,10 @@ getfeatures(io::IO, index::Integer) = vec(getimage(io, index))
|
||||||
|
|
||||||
Load the MNIST images.
|
Load the MNIST images.
|
||||||
|
|
||||||
Each image is a 28×28 array of `Gray` colour values (see Colors.jl).
|
Each image is a 28×28 array of `Gray` colour values
|
||||||
|
(see [Colors.jl](https://github.com/JuliaGraphics/Colors.jl)).
|
||||||
|
|
||||||
Returns the 60,000 training images by default; pass `:test` to retreive the
|
Return the 60,000 training images by default; pass `:test` to retrieve the
|
||||||
10,000 test images.
|
10,000 test images.
|
||||||
"""
|
"""
|
||||||
function images(set = :train)
|
function images(set = :train)
|
||||||
|
@ -99,10 +100,10 @@ end
|
||||||
labels()
|
labels()
|
||||||
labels(:test)
|
labels(:test)
|
||||||
|
|
||||||
Load the labels corresponding to each of the images returned from `images()`.
|
Load the labels corresponding to each of the images returned from [`images()`](@ref).
|
||||||
Each label is a number from 0-9.
|
Each label is a number from 0-9.
|
||||||
|
|
||||||
Returns the 60,000 training labels by default; pass `:test` to retreive the
|
Return the 60,000 training labels by default; pass `:test` to retrieve the
|
||||||
10,000 test labels.
|
10,000 test labels.
|
||||||
"""
|
"""
|
||||||
function labels(set = :train)
|
function labels(set = :train)
|
||||||
|
|
|
@ -1,3 +1,4 @@
|
||||||
|
"Stanford Sentiment Treebank dataset."
|
||||||
module Sentiment
|
module Sentiment
|
||||||
|
|
||||||
using ZipFile
|
using ZipFile
|
||||||
|
@ -39,8 +40,28 @@ function gettrees(name)
|
||||||
return parsetree.(ss)
|
return parsetree.(ss)
|
||||||
end
|
end
|
||||||
|
|
||||||
|
"""
|
||||||
|
train()
|
||||||
|
|
||||||
|
Return the train split of the Stanford Sentiment Treebank.
|
||||||
|
The data is in [treebank](https://en.wikipedia.org/wiki/Treebank) format.
|
||||||
|
"""
|
||||||
train() = gettrees("train")
|
train() = gettrees("train")
|
||||||
|
|
||||||
|
"""
|
||||||
|
test()
|
||||||
|
|
||||||
|
Return the test split of the Stanford Sentiment Treebank.
|
||||||
|
The data is in [treebank](https://en.wikipedia.org/wiki/Treebank) format.
|
||||||
|
"""
|
||||||
test() = gettrees("test")
|
test() = gettrees("test")
|
||||||
|
|
||||||
|
"""
|
||||||
|
dev()
|
||||||
|
|
||||||
|
Return the dev split of the Stanford Sentiment Treebank.
|
||||||
|
The data is in [treebank](https://en.wikipedia.org/wiki/Treebank) format.
|
||||||
|
"""
|
||||||
dev() = gettrees("dev")
|
dev() = gettrees("dev")
|
||||||
|
|
||||||
end
|
end
|
||||||
|
|
|
@ -4,17 +4,23 @@
|
||||||
Chain multiple layers / functions together, so that they are called in sequence
|
Chain multiple layers / functions together, so that they are called in sequence
|
||||||
on a given input.
|
on a given input.
|
||||||
|
|
||||||
```julia
|
|
||||||
m = Chain(x -> x^2, x -> x+1)
|
|
||||||
m(5) == 26
|
|
||||||
|
|
||||||
m = Chain(Dense(10, 5), Dense(5, 2))
|
|
||||||
x = rand(10)
|
|
||||||
m(x) == m[2](m[1](x))
|
|
||||||
```
|
|
||||||
|
|
||||||
`Chain` also supports indexing and slicing, e.g. `m[2]` or `m[1:end-1]`.
|
`Chain` also supports indexing and slicing, e.g. `m[2]` or `m[1:end-1]`.
|
||||||
`m[1:3](x)` will calculate the output of the first three layers.
|
`m[1:3](x)` will calculate the output of the first three layers.
|
||||||
|
|
||||||
|
# Examples
|
||||||
|
```jldoctest
|
||||||
|
julia> m = Chain(x -> x^2, x -> x+1);
|
||||||
|
|
||||||
|
julia> m(5) == 26
|
||||||
|
true
|
||||||
|
|
||||||
|
julia> m = Chain(Dense(10, 5), Dense(5, 2));
|
||||||
|
|
||||||
|
julia> x = rand(10);
|
||||||
|
|
||||||
|
julia> m(x) == m[2](m[1](x))
|
||||||
|
true
|
||||||
|
```
|
||||||
"""
|
"""
|
||||||
struct Chain{T<:Tuple}
|
struct Chain{T<:Tuple}
|
||||||
layers::T
|
layers::T
|
||||||
|
@ -60,6 +66,7 @@ outdims(c::Chain, isize) = foldl(∘, map(l -> (x -> outdims(l, x)), c.layers))(
|
||||||
# only slightly changed to better handle interaction with Zygote @dsweber2
|
# only slightly changed to better handle interaction with Zygote @dsweber2
|
||||||
"""
|
"""
|
||||||
activations(c::Chain, input)
|
activations(c::Chain, input)
|
||||||
|
|
||||||
Calculate the forward results of each layers in Chain `c` with `input` as model input.
|
Calculate the forward results of each layers in Chain `c` with `input` as model input.
|
||||||
"""
|
"""
|
||||||
function activations(c::Chain, input)
|
function activations(c::Chain, input)
|
||||||
|
@ -78,22 +85,22 @@ extraChain(::Tuple{}, x) = ()
|
||||||
"""
|
"""
|
||||||
Dense(in::Integer, out::Integer, σ = identity)
|
Dense(in::Integer, out::Integer, σ = identity)
|
||||||
|
|
||||||
Creates a traditional `Dense` layer with parameters `W` and `b`.
|
Create a traditional `Dense` layer with parameters `W` and `b`.
|
||||||
|
|
||||||
y = σ.(W * x .+ b)
|
y = σ.(W * x .+ b)
|
||||||
|
|
||||||
The input `x` must be a vector of length `in`, or a batch of vectors represented
|
The input `x` must be a vector of length `in`, or a batch of vectors represented
|
||||||
as an `in × N` matrix. The out `y` will be a vector or batch of length `out`.
|
as an `in × N` matrix. The out `y` will be a vector or batch of length `out`.
|
||||||
|
|
||||||
```julia
|
# Examples
|
||||||
|
```jldoctest; setup = :(using Random; Random.seed!(0))
|
||||||
julia> d = Dense(5, 2)
|
julia> d = Dense(5, 2)
|
||||||
Dense(5, 2)
|
Dense(5, 2)
|
||||||
|
|
||||||
julia> d(rand(5))
|
julia> d(rand(5))
|
||||||
Array{Float64,1}:
|
2-element Array{Float32,1}:
|
||||||
0.00257447
|
-0.16210233
|
||||||
-0.00449443
|
0.12311903```
|
||||||
```
|
|
||||||
"""
|
"""
|
||||||
struct Dense{F,S,T}
|
struct Dense{F,S,T}
|
||||||
W::S
|
W::S
|
||||||
|
@ -145,7 +152,7 @@ outdims(l::Dense, isize) = (size(l.W)[1],)
|
||||||
"""
|
"""
|
||||||
Diagonal(in::Integer)
|
Diagonal(in::Integer)
|
||||||
|
|
||||||
Creates an element-wise linear transformation layer with learnable
|
Create an element-wise linear transformation layer with learnable
|
||||||
vectors `α` and `β`:
|
vectors `α` and `β`:
|
||||||
|
|
||||||
y = α .* x .+ β
|
y = α .* x .+ β
|
||||||
|
@ -176,18 +183,11 @@ outdims(l::Diagonal, isize) = (length(l.α),)
|
||||||
"""
|
"""
|
||||||
Maxout(over)
|
Maxout(over)
|
||||||
|
|
||||||
`Maxout` is a neural network layer, which has a number of internal layers,
|
The [Maxout](https://arxiv.org/pdf/1302.4389.pdf) layer has a number of
|
||||||
which all have the same input, and the maxout returns the elementwise maximium
|
internal layers which all receive the same input. It returns the elementwise
|
||||||
of the internal layers' outputs.
|
maximum of the internal layers' outputs.
|
||||||
|
|
||||||
Maxout over linear dense layers satisfies the univeral approximation theorem.
|
Maxout over linear dense layers satisfies the univeral approximation theorem.
|
||||||
|
|
||||||
Reference:
|
|
||||||
Ian J. Goodfellow, David Warde-Farley, Mehdi Mirza, Aaron Courville, and Yoshua Bengio.
|
|
||||||
2013. Maxout networks.
|
|
||||||
In Proceedings of the 30th International Conference on International Conference on Machine Learning - Volume 28 (ICML'13),
|
|
||||||
Sanjoy Dasgupta and David McAllester (Eds.), Vol. 28. JMLR.org III-1319-III-1327.
|
|
||||||
https://arxiv.org/pdf/1302.4389.pdf
|
|
||||||
"""
|
"""
|
||||||
struct Maxout{FS<:Tuple}
|
struct Maxout{FS<:Tuple}
|
||||||
over::FS
|
over::FS
|
||||||
|
@ -196,17 +196,18 @@ end
|
||||||
"""
|
"""
|
||||||
Maxout(f, n_alts)
|
Maxout(f, n_alts)
|
||||||
|
|
||||||
Constructs a Maxout layer over `n_alts` instances of the layer given by `f`.
|
Construct a Maxout layer over `n_alts` instances of the layer given by `f`.
|
||||||
The function takes no arguement and should return some callable layer.
|
The function takes no arguments and should return some callable layer.
|
||||||
Conventionally this is a linear dense layer.
|
Conventionally, this is a linear dense layer.
|
||||||
|
|
||||||
For example the following example which
|
# Examples
|
||||||
will construct a `Maxout` layer over 4 internal dense linear layers,
|
|
||||||
each identical in structure (784 inputs, 128 outputs).
|
This constructs a `Maxout` layer over 4 internal dense linear layers, each
|
||||||
|
identical in structure (784 inputs, 128 outputs):
|
||||||
```julia
|
```julia
|
||||||
insize = 784
|
insize = 784
|
||||||
outsize = 128
|
outsize = 128
|
||||||
Maxout(()->Dense(insize, outsize), 4)
|
Maxout(()->Dense(insize, outsize), 4)
|
||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
function Maxout(f, n_alts)
|
function Maxout(f, n_alts)
|
||||||
|
@ -223,16 +224,18 @@ end
|
||||||
outdims(l::Maxout, isize) = outdims(first(l.over), isize)
|
outdims(l::Maxout, isize) = outdims(first(l.over), isize)
|
||||||
|
|
||||||
"""
|
"""
|
||||||
SkipConnection(layers, connection)
|
SkipConnection(layer, connection)
|
||||||
|
|
||||||
Creates a Skip Connection, of a layer or `Chain` of consecutive layers
|
Create a skip connection which consists of a layer or `Chain` of consecutive
|
||||||
plus a shortcut connection. The connection function will combine the result of the layers
|
layers and a shortcut connection linking the block's input to the output
|
||||||
with the original input, to give the final output.
|
through a user-supplied 2-argument callable. The first argument to the callable
|
||||||
|
will be propagated through the given `layer` while the second is the unchanged,
|
||||||
|
"skipped" input.
|
||||||
|
|
||||||
The simplest 'ResNet'-type connection is just `SkipConnection(layer, +)`,
|
The simplest "ResNet"-type connection is just `SkipConnection(layer, +)`,
|
||||||
and requires the output of the layers to be the same shape as the input.
|
and requires the output of the layers to be the same shape as the input.
|
||||||
Here is a more complicated example:
|
Here is a more complicated example:
|
||||||
```
|
```julia
|
||||||
m = Conv((3,3), 4=>7, pad=(1,1))
|
m = Conv((3,3), 4=>7, pad=(1,1))
|
||||||
x = ones(5,5,4,10);
|
x = ones(5,5,4,10);
|
||||||
size(m(x)) == (5, 5, 7, 10)
|
size(m(x)) == (5, 5, 7, 10)
|
||||||
|
|
|
@ -8,25 +8,26 @@ _convtransoutdims(isize, ksize, ssize, dsize, pad) = (isize .- 1).*ssize .+ 1 .+
|
||||||
expand(N, i::Tuple) = i
|
expand(N, i::Tuple) = i
|
||||||
expand(N, i::Integer) = ntuple(_ -> i, N)
|
expand(N, i::Integer) = ntuple(_ -> i, N)
|
||||||
"""
|
"""
|
||||||
Conv(size, in=>out)
|
Conv(size, in => out, σ = identity; init = glorot_uniform,
|
||||||
Conv(size, in=>out, relu)
|
stride = 1, pad = 0, dilation = 1)
|
||||||
|
|
||||||
Standard convolutional layer. `size` should be a tuple like `(2, 2)`.
|
Standard convolutional layer. `size` should be a tuple like `(2, 2)`.
|
||||||
`in` and `out` specify the number of input and output channels respectively.
|
`in` and `out` specify the number of input and output channels respectively.
|
||||||
|
|
||||||
Example: Applying Conv layer to a 1-channel input using a 2x2 window size,
|
|
||||||
giving us a 16-channel output. Output is activated with ReLU.
|
|
||||||
|
|
||||||
size = (2,2)
|
|
||||||
in = 1
|
|
||||||
out = 16
|
|
||||||
Conv((2, 2), 1=>16, relu)
|
|
||||||
|
|
||||||
Data should be stored in WHCN order (width, height, # channels, batch size).
|
Data should be stored in WHCN order (width, height, # channels, batch size).
|
||||||
In other words, a 100×100 RGB image would be a `100×100×3×1` array,
|
In other words, a 100×100 RGB image would be a `100×100×3×1` array,
|
||||||
and a batch of 50 would be a `100×100×3×50` array.
|
and a batch of 50 would be a `100×100×3×50` array.
|
||||||
|
|
||||||
Takes the keyword arguments `pad`, `stride` and `dilation`.
|
# Examples
|
||||||
|
|
||||||
|
Apply a `Conv` layer to a 1-channel input using a 2×2 window size, giving us a
|
||||||
|
16-channel output. Output is activated with ReLU.
|
||||||
|
```julia
|
||||||
|
size = (2,2)
|
||||||
|
in = 1
|
||||||
|
out = 16
|
||||||
|
Conv(size, in => out, relu)
|
||||||
|
```
|
||||||
"""
|
"""
|
||||||
struct Conv{N,M,F,A,V}
|
struct Conv{N,M,F,A,V}
|
||||||
σ::F
|
σ::F
|
||||||
|
@ -76,8 +77,8 @@ end
|
||||||
"""
|
"""
|
||||||
outdims(l::Conv, isize::Tuple)
|
outdims(l::Conv, isize::Tuple)
|
||||||
|
|
||||||
Calculate the output dimensions given the input dimensions, `isize`.
|
Calculate the output dimensions given the input dimensions `isize`.
|
||||||
Batch size and channel size are ignored as per `NNlib.jl`.
|
Batch size and channel size are ignored as per [NNlib.jl](https://github.com/FluxML/NNlib.jl).
|
||||||
|
|
||||||
```julia
|
```julia
|
||||||
m = Conv((3, 3), 3 => 16)
|
m = Conv((3, 3), 3 => 16)
|
||||||
|
@ -89,17 +90,15 @@ outdims(l::Conv, isize) =
|
||||||
output_size(DenseConvDims(_paddims(isize, size(l.weight)), size(l.weight); stride = l.stride, padding = l.pad, dilation = l.dilation))
|
output_size(DenseConvDims(_paddims(isize, size(l.weight)), size(l.weight); stride = l.stride, padding = l.pad, dilation = l.dilation))
|
||||||
|
|
||||||
"""
|
"""
|
||||||
ConvTranspose(size, in=>out)
|
ConvTranspose(size, in => out, σ = identity; init = glorot_uniform,
|
||||||
ConvTranspose(size, in=>out, relu)
|
stride = 1, pad = 0, dilation = 1)
|
||||||
|
|
||||||
Standard convolutional transpose layer. `size` should be a tuple like `(2, 2)`.
|
Standard convolutional transpose layer. `size` should be a tuple like `(2, 2)`.
|
||||||
`in` and `out` specify the number of input and output channels respectively.
|
`in` and `out` specify the number of input and output channels respectively.
|
||||||
|
|
||||||
Data should be stored in WHCN order (width, height, # channels, # batches).
|
Data should be stored in WHCN order (width, height, # channels, batch size).
|
||||||
In other words, a 100×100 RGB image would be a `100×100×3×1` array,
|
In other words, a 100×100 RGB image would be a `100×100×3×1` array,
|
||||||
and a batch of 50 would be a `100×100×3×50` array.
|
and a batch of 50 would be a `100×100×3×50` array.
|
||||||
|
|
||||||
Takes the keyword arguments `pad`, `stride` and `dilation`.
|
|
||||||
"""
|
"""
|
||||||
struct ConvTranspose{N,M,F,A,V}
|
struct ConvTranspose{N,M,F,A,V}
|
||||||
σ::F
|
σ::F
|
||||||
|
@ -165,18 +164,16 @@ end
|
||||||
outdims(l::ConvTranspose{N}, isize) where N = _convtransoutdims(isize[1:2], size(l.weight)[1:N], l.stride, l.dilation, l.pad)
|
outdims(l::ConvTranspose{N}, isize) where N = _convtransoutdims(isize[1:2], size(l.weight)[1:N], l.stride, l.dilation, l.pad)
|
||||||
|
|
||||||
"""
|
"""
|
||||||
DepthwiseConv(size, in=>out)
|
DepthwiseConv(size, in => out, σ = identity; init = glorot_uniform,
|
||||||
DepthwiseConv(size, in=>out, relu)
|
stride = 1, pad = 0, dilation = 1)
|
||||||
|
|
||||||
Depthwise convolutional layer. `size` should be a tuple like `(2, 2)`.
|
Depthwise convolutional layer. `size` should be a tuple like `(2, 2)`.
|
||||||
`in` and `out` specify the number of input and output channels respectively.
|
`in` and `out` specify the number of input and output channels respectively.
|
||||||
Note that `out` must be an integer multiple of `in`.
|
Note that `out` must be an integer multiple of `in`.
|
||||||
|
|
||||||
Data should be stored in WHCN order (width, height, # channels, # batches).
|
Data should be stored in WHCN order (width, height, # channels, batch size).
|
||||||
In other words, a 100×100 RGB image would be a `100×100×3×1` array,
|
In other words, a 100×100 RGB image would be a `100×100×3×1` array,
|
||||||
and a batch of 50 would be a `100×100×3×50` array.
|
and a batch of 50 would be a `100×100×3×50` array.
|
||||||
|
|
||||||
Takes the keyword arguments `pad`, `stride` and `dilation`.
|
|
||||||
"""
|
"""
|
||||||
struct DepthwiseConv{N,M,F,A,V}
|
struct DepthwiseConv{N,M,F,A,V}
|
||||||
σ::F
|
σ::F
|
||||||
|
@ -233,25 +230,26 @@ outdims(l::DepthwiseConv, isize) =
|
||||||
output_size(DepthwiseConvDims(_paddims(isize, (1, 1, size(l.weight)[end], 1)), size(l.weight); stride = l.stride, padding = l.pad, dilation = l.dilation))
|
output_size(DepthwiseConvDims(_paddims(isize, (1, 1, size(l.weight)[end], 1)), size(l.weight); stride = l.stride, padding = l.pad, dilation = l.dilation))
|
||||||
|
|
||||||
"""
|
"""
|
||||||
CrossCor(size, in=>out)
|
CrossCor(size, in => out, σ = identity; init = glorot_uniform,
|
||||||
CrossCor(size, in=>out, relu)
|
stride = 1, pad = 0, dilation = 1)
|
||||||
|
|
||||||
Standard cross convolutional layer. `size` should be a tuple like `(2, 2)`.
|
Standard cross convolutional layer. `size` should be a tuple like `(2, 2)`.
|
||||||
`in` and `out` specify the number of input and output channels respectively.
|
`in` and `out` specify the number of input and output channels respectively.
|
||||||
|
|
||||||
Example: Applying CrossCor layer to a 1-channel input using a 2x2 window size,
|
Data should be stored in WHCN order (width, height, # channels, batch size).
|
||||||
giving us a 16-channel output. Output is activated with ReLU.
|
|
||||||
|
|
||||||
size = (2,2)
|
|
||||||
in = 1
|
|
||||||
out = 16
|
|
||||||
CrossCor((2, 2), 1=>16, relu)
|
|
||||||
|
|
||||||
Data should be stored in WHCN order (width, height, # channels, # batches).
|
|
||||||
In other words, a 100×100 RGB image would be a `100×100×3×1` array,
|
In other words, a 100×100 RGB image would be a `100×100×3×1` array,
|
||||||
and a batch of 50 would be a `100×100×3×50` array.
|
and a batch of 50 would be a `100×100×3×50` array.
|
||||||
|
|
||||||
Takes the keyword arguments `pad`, `stride` and `dilation`.
|
# Examples
|
||||||
|
|
||||||
|
Apply a `CrossCor` layer to a 1-channel input using a 2×2 window size, giving us a
|
||||||
|
16-channel output. Output is activated with ReLU.
|
||||||
|
```julia
|
||||||
|
size = (2,2)
|
||||||
|
in = 1
|
||||||
|
out = 16
|
||||||
|
CrossCor((2, 2), 1=>16, relu)
|
||||||
|
```
|
||||||
"""
|
"""
|
||||||
struct CrossCor{N,M,F,A,V}
|
struct CrossCor{N,M,F,A,V}
|
||||||
σ::F
|
σ::F
|
||||||
|
@ -357,11 +355,9 @@ function Base.show(io::IO, g::GlobalMeanPool)
|
||||||
end
|
end
|
||||||
|
|
||||||
"""
|
"""
|
||||||
MaxPool(k)
|
MaxPool(k; pad = 0, stride = k)
|
||||||
|
|
||||||
Max pooling layer. `k` stands for the size of the window for each dimension of the input.
|
Max pooling layer. `k` is the size of the window for each dimension of the input.
|
||||||
|
|
||||||
Takes the keyword arguments `pad` and `stride`.
|
|
||||||
"""
|
"""
|
||||||
struct MaxPool{N,M}
|
struct MaxPool{N,M}
|
||||||
k::NTuple{N,Int}
|
k::NTuple{N,Int}
|
||||||
|
@ -388,11 +384,9 @@ end
|
||||||
outdims(l::MaxPool{N}, isize) where N = output_size(PoolDims(_paddims(isize, (l.k..., 1, 1)), l.k; stride = l.stride, padding = l.pad))
|
outdims(l::MaxPool{N}, isize) where N = output_size(PoolDims(_paddims(isize, (l.k..., 1, 1)), l.k; stride = l.stride, padding = l.pad))
|
||||||
|
|
||||||
"""
|
"""
|
||||||
MeanPool(k)
|
MeanPool(k; pad = 0, stride = k)
|
||||||
|
|
||||||
Mean pooling layer. `k` stands for the size of the window for each dimension of the input.
|
Mean pooling layer. `k` is the size of the window for each dimension of the input.
|
||||||
|
|
||||||
Takes the keyword arguments `pad` and `stride`.
|
|
||||||
"""
|
"""
|
||||||
struct MeanPool{N,M}
|
struct MeanPool{N,M}
|
||||||
k::NTuple{N,Int}
|
k::NTuple{N,Int}
|
||||||
|
|
|
@ -10,14 +10,14 @@ _dropout_shape(s, dims) = tuple((i ∉ dims ? 1 : si for (i, si) ∈ enumerate(s
|
||||||
_dropout_kernel(y::T, p, q) where {T} = y > p ? T(1 / q) : T(0)
|
_dropout_kernel(y::T, p, q) where {T} = y > p ? T(1 / q) : T(0)
|
||||||
|
|
||||||
"""
|
"""
|
||||||
dropout(p, dims = :)
|
dropout(x, p; dims = :)
|
||||||
|
|
||||||
Dropout function. For each input, either sets that input to `0` (with probability
|
The dropout function. For each input, either sets that input to `0` (with probability
|
||||||
`p`) or scales it by `1/(1-p)`. The `dims` argument is to specify the unbroadcasted
|
`p`) or scales it by `1 / (1 - p)`. `dims` specifies the unbroadcasted dimensions,
|
||||||
dimensions, i.e. `dims=1` does dropout along columns and `dims=2` along rows. This is
|
e.g. `dims=1` applies dropout along columns and `dims=2` along rows.
|
||||||
used as a regularisation, i.e. it reduces overfitting during training.
|
This is used as a regularisation, i.e. it reduces overfitting during training.
|
||||||
|
|
||||||
See also [`Dropout`](@ref).
|
See also the [`Dropout`](@ref) layer.
|
||||||
"""
|
"""
|
||||||
dropout(x, p; dims = :) = x
|
dropout(x, p; dims = :) = x
|
||||||
|
|
||||||
|
@ -30,9 +30,9 @@ end
|
||||||
"""
|
"""
|
||||||
Dropout(p, dims = :)
|
Dropout(p, dims = :)
|
||||||
|
|
||||||
A Dropout layer. In the forward pass, applies the [`dropout`](@ref) function on the input.
|
Dropout layer. In the forward pass, apply the [`Flux.dropout`](@ref) function on the input.
|
||||||
|
|
||||||
Does nothing to the input once [`testmode!`](@ref) is true.
|
Does nothing to the input once [`Flux.testmode!`](@ref) is `true`.
|
||||||
"""
|
"""
|
||||||
mutable struct Dropout{F,D}
|
mutable struct Dropout{F,D}
|
||||||
p::F
|
p::F
|
||||||
|
@ -64,10 +64,11 @@ end
|
||||||
|
|
||||||
"""
|
"""
|
||||||
AlphaDropout(p)
|
AlphaDropout(p)
|
||||||
|
|
||||||
A dropout layer. It is used in Self-Normalizing Neural Networks.
|
A dropout layer. Used in
|
||||||
(https://papers.nips.cc/paper/6698-self-normalizing-neural-networks.pdf)
|
[Self-Normalizing Neural Networks](https://papers.nips.cc/paper/6698-self-normalizing-neural-networks.pdf).
|
||||||
The AlphaDropout layer ensures that mean and variance of activations remains the same as before.
|
The AlphaDropout layer ensures that mean and variance of activations
|
||||||
|
remain the same as before.
|
||||||
|
|
||||||
Does nothing to the input once [`testmode!`](@ref) is true.
|
Does nothing to the input once [`testmode!`](@ref) is true.
|
||||||
"""
|
"""
|
||||||
|
@ -100,8 +101,8 @@ testmode!(m::AlphaDropout, mode = true) =
|
||||||
LayerNorm(h::Integer)
|
LayerNorm(h::Integer)
|
||||||
|
|
||||||
A [normalisation layer](https://arxiv.org/pdf/1607.06450.pdf) designed to be
|
A [normalisation layer](https://arxiv.org/pdf/1607.06450.pdf) designed to be
|
||||||
used with recurrent hidden states of size `h`. Normalises the mean/stddev of
|
used with recurrent hidden states of size `h`. Normalises the mean and standard
|
||||||
each input before applying a per-neuron gain/bias.
|
deviation of each input before applying a per-neuron gain/bias.
|
||||||
"""
|
"""
|
||||||
struct LayerNorm{T}
|
struct LayerNorm{T}
|
||||||
diag::Diagonal{T}
|
diag::Diagonal{T}
|
||||||
|
@ -123,8 +124,8 @@ end
|
||||||
initβ = zeros, initγ = ones,
|
initβ = zeros, initγ = ones,
|
||||||
ϵ = 1e-8, momentum = .1)
|
ϵ = 1e-8, momentum = .1)
|
||||||
|
|
||||||
Batch Normalization layer. The `channels` input should be the size of the
|
[Batch Normalization](https://arxiv.org/pdf/1502.03167.pdf) layer.
|
||||||
channel dimension in your data (see below).
|
`channels` should be the size of the channel dimension in your data (see below).
|
||||||
|
|
||||||
Given an array with `N` dimensions, call the `N-1`th the channel dimension. (For
|
Given an array with `N` dimensions, call the `N-1`th the channel dimension. (For
|
||||||
a batch of feature vectors this is just the data dimension, for `WHCN` images
|
a batch of feature vectors this is just the data dimension, for `WHCN` images
|
||||||
|
@ -136,10 +137,7 @@ per-channel `bias` and `scale` parameters).
|
||||||
|
|
||||||
Use [`testmode!`](@ref) during inference.
|
Use [`testmode!`](@ref) during inference.
|
||||||
|
|
||||||
See [Batch Normalization: Accelerating Deep Network Training by Reducing
|
# Examples
|
||||||
Internal Covariate Shift](https://arxiv.org/pdf/1502.03167.pdf).
|
|
||||||
|
|
||||||
Example:
|
|
||||||
```julia
|
```julia
|
||||||
m = Chain(
|
m = Chain(
|
||||||
Dense(28^2, 64),
|
Dense(28^2, 64),
|
||||||
|
@ -213,37 +211,6 @@ function Base.show(io::IO, l::BatchNorm)
|
||||||
print(io, ")")
|
print(io, ")")
|
||||||
end
|
end
|
||||||
|
|
||||||
|
|
||||||
"""
|
|
||||||
InstanceNorm(channels::Integer, σ = identity;
|
|
||||||
initβ = zeros, initγ = ones,
|
|
||||||
ϵ = 1e-8, momentum = .1)
|
|
||||||
|
|
||||||
Instance Normalization layer. The `channels` input should be the size of the
|
|
||||||
channel dimension in your data (see below).
|
|
||||||
|
|
||||||
Given an array with `N` dimensions, call the `N-1`th the channel dimension. (For
|
|
||||||
a batch of feature vectors this is just the data dimension, for `WHCN` images
|
|
||||||
it's the usual channel dimension.)
|
|
||||||
|
|
||||||
`InstanceNorm` computes the mean and variance for each each `W×H×1×1` slice and
|
|
||||||
shifts them to have a new mean and variance (corresponding to the learnable,
|
|
||||||
per-channel `bias` and `scale` parameters).
|
|
||||||
|
|
||||||
Use [`testmode!`](@ref) during inference.
|
|
||||||
|
|
||||||
See [Instance Normalization: The Missing Ingredient for Fast Stylization](https://arxiv.org/abs/1607.08022).
|
|
||||||
|
|
||||||
Example:
|
|
||||||
```julia
|
|
||||||
m = Chain(
|
|
||||||
Dense(28^2, 64),
|
|
||||||
InstanceNorm(64, relu),
|
|
||||||
Dense(64, 10),
|
|
||||||
InstanceNorm(10),
|
|
||||||
softmax)
|
|
||||||
```
|
|
||||||
"""
|
|
||||||
expand_inst = (x, as) -> reshape(repeat(x, outer=[1, as[length(as)]]), as...)
|
expand_inst = (x, as) -> reshape(repeat(x, outer=[1, as[length(as)]]), as...)
|
||||||
|
|
||||||
mutable struct InstanceNorm{F,V,W,N}
|
mutable struct InstanceNorm{F,V,W,N}
|
||||||
|
@ -258,6 +225,34 @@ mutable struct InstanceNorm{F,V,W,N}
|
||||||
end
|
end
|
||||||
|
|
||||||
# TODO: deprecate in v0.11
|
# TODO: deprecate in v0.11
|
||||||
|
"""
|
||||||
|
InstanceNorm(channels::Integer, σ = identity;
|
||||||
|
initβ = zeros, initγ = ones,
|
||||||
|
ϵ = 1e-8, momentum = .1)
|
||||||
|
|
||||||
|
[Instance Normalization](https://arxiv.org/abs/1607.08022) layer.
|
||||||
|
`channels` should be the size of the channel dimension in your data (see below).
|
||||||
|
|
||||||
|
Given an array with `N` dimensions, call the `N-1`th the channel dimension. (For
|
||||||
|
a batch of feature vectors this is just the data dimension, for `WHCN` images
|
||||||
|
it's the usual channel dimension.)
|
||||||
|
|
||||||
|
`InstanceNorm` computes the mean and variance for each each `W×H×1×1` slice and
|
||||||
|
shifts them to have a new mean and variance (corresponding to the learnable,
|
||||||
|
per-channel `bias` and `scale` parameters).
|
||||||
|
|
||||||
|
Use [`testmode!`](@ref) during inference.
|
||||||
|
|
||||||
|
# Examples
|
||||||
|
```julia
|
||||||
|
m = Chain(
|
||||||
|
Dense(28^2, 64),
|
||||||
|
InstanceNorm(64, relu),
|
||||||
|
Dense(64, 10),
|
||||||
|
InstanceNorm(10),
|
||||||
|
softmax)
|
||||||
|
```
|
||||||
|
"""
|
||||||
InstanceNorm(λ, β, γ, μ, σ², ϵ, momentum) = InstanceNorm(λ, β, γ, μ, σ², ϵ, momentum, nothing)
|
InstanceNorm(λ, β, γ, μ, σ², ϵ, momentum) = InstanceNorm(λ, β, γ, μ, σ², ϵ, momentum, nothing)
|
||||||
|
|
||||||
InstanceNorm(chs::Integer, λ = identity;
|
InstanceNorm(chs::Integer, λ = identity;
|
||||||
|
@ -316,28 +311,27 @@ function Base.show(io::IO, l::InstanceNorm)
|
||||||
end
|
end
|
||||||
|
|
||||||
"""
|
"""
|
||||||
Group Normalization.
|
GroupNorm(chs::Integer, G::Integer, λ = identity;
|
||||||
This layer can outperform Batch-Normalization and Instance-Normalization.
|
initβ = (i) -> zeros(Float32, i), initγ = (i) -> ones(Float32, i),
|
||||||
|
ϵ = 1f-5, momentum = 0.1f0)
|
||||||
|
|
||||||
GroupNorm(chs::Integer, G::Integer, λ = identity;
|
[Group Normalization](https://arxiv.org/pdf/1803.08494.pdf) layer.
|
||||||
initβ = (i) -> zeros(Float32, i), initγ = (i) -> ones(Float32, i),
|
This layer can outperform Batch Normalization and Instance Normalization.
|
||||||
ϵ = 1f-5, momentum = 0.1f0)
|
|
||||||
|
|
||||||
``chs`` is the number of channels, the channel dimension of your input.
|
`chs` is the number of channels, the channel dimension of your input.
|
||||||
For an array of N dimensions, the (N-1)th index is the channel dimension.
|
For an array of N dimensions, the `N-1`th index is the channel dimension.
|
||||||
|
|
||||||
``G`` is the number of groups along which the statistics would be computed.
|
`G` is the number of groups along which the statistics are computed.
|
||||||
The number of channels must be an integer multiple of the number of groups.
|
The number of channels must be an integer multiple of the number of groups.
|
||||||
|
|
||||||
Use [`testmode!`](@ref) during inference.
|
Use [`testmode!`](@ref) during inference.
|
||||||
|
|
||||||
Example:
|
# Examples
|
||||||
```
|
```julia
|
||||||
m = Chain(Conv((3,3), 1=>32, leakyrelu;pad = 1),
|
m = Chain(Conv((3,3), 1=>32, leakyrelu;pad = 1),
|
||||||
GroupNorm(32,16)) # 32 channels, 16 groups (G = 16), thus 2 channels per group used
|
GroupNorm(32,16))
|
||||||
|
# 32 channels, 16 groups (G = 16), thus 2 channels per group used
|
||||||
```
|
```
|
||||||
|
|
||||||
Link : https://arxiv.org/pdf/1803.08494.pdf
|
|
||||||
"""
|
"""
|
||||||
mutable struct GroupNorm{F,V,W,N,T}
|
mutable struct GroupNorm{F,V,W,N,T}
|
||||||
G::T # number of groups
|
G::T # number of groups
|
||||||
|
|
|
@ -12,16 +12,16 @@ in the background. `cell` should be a model of the form:
|
||||||
|
|
||||||
h, y = cell(h, x...)
|
h, y = cell(h, x...)
|
||||||
|
|
||||||
For example, here's a recurrent network that keeps a running total of its inputs.
|
For example, here's a recurrent network that keeps a running total of its inputs:
|
||||||
|
|
||||||
```julia
|
```julia
|
||||||
accum(h, x) = (h+x, x)
|
accum(h, x) = (h + x, x)
|
||||||
rnn = Flux.Recur(accum, 0)
|
rnn = Flux.Recur(accum, 0)
|
||||||
rnn(2) # 2
|
rnn(2) # 2
|
||||||
rnn(3) # 3
|
rnn(3) # 3
|
||||||
rnn.state # 5
|
rnn.state # 5
|
||||||
rnn.(1:10) # apply to a sequence
|
rnn.(1:10) # apply to a sequence
|
||||||
rnn.state # 60
|
rnn.state # 60
|
||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
mutable struct Recur{T}
|
mutable struct Recur{T}
|
||||||
|
@ -47,9 +47,10 @@ Base.show(io::IO, m::Recur) = print(io, "Recur(", m.cell, ")")
|
||||||
|
|
||||||
Reset the hidden state of a recurrent layer back to its original value.
|
Reset the hidden state of a recurrent layer back to its original value.
|
||||||
|
|
||||||
Assuming you have a `Recur` layer `rnn`, this is roughly equivalent to
|
Assuming you have a `Recur` layer `rnn`, this is roughly equivalent to:
|
||||||
|
```julia
|
||||||
rnn.state = hidden(rnn.cell)
|
rnn.state = hidden(rnn.cell)
|
||||||
|
```
|
||||||
"""
|
"""
|
||||||
reset!(m::Recur) = (m.state = m.init)
|
reset!(m::Recur) = (m.state = m.init)
|
||||||
reset!(m) = foreach(reset!, functor(m)[1])
|
reset!(m) = foreach(reset!, functor(m)[1])
|
||||||
|
@ -135,8 +136,8 @@ Base.show(io::IO, l::LSTMCell) =
|
||||||
"""
|
"""
|
||||||
LSTM(in::Integer, out::Integer)
|
LSTM(in::Integer, out::Integer)
|
||||||
|
|
||||||
Long Short Term Memory recurrent layer. Behaves like an RNN but generally
|
[Long Short Term Memory](https://www.researchgate.net/publication/13853244_Long_Short-term_Memory)
|
||||||
exhibits a longer memory span over sequences.
|
recurrent layer. Behaves like an RNN but generally exhibits a longer memory span over sequences.
|
||||||
|
|
||||||
See [this article](https://colah.github.io/posts/2015-08-Understanding-LSTMs/)
|
See [this article](https://colah.github.io/posts/2015-08-Understanding-LSTMs/)
|
||||||
for a good overview of the internals.
|
for a good overview of the internals.
|
||||||
|
@ -176,8 +177,8 @@ Base.show(io::IO, l::GRUCell) =
|
||||||
"""
|
"""
|
||||||
GRU(in::Integer, out::Integer)
|
GRU(in::Integer, out::Integer)
|
||||||
|
|
||||||
Gated Recurrent Unit layer. Behaves like an RNN but generally
|
[Gated Recurrent Unit](https://arxiv.org/abs/1406.1078) layer. Behaves like an
|
||||||
exhibits a longer memory span over sequences.
|
RNN but generally exhibits a longer memory span over sequences.
|
||||||
|
|
||||||
See [this article](https://colah.github.io/posts/2015-08-Understanding-LSTMs/)
|
See [this article](https://colah.github.io/posts/2015-08-Understanding-LSTMs/)
|
||||||
for a good overview of the internals.
|
for a good overview of the internals.
|
||||||
|
|
|
@ -2,7 +2,8 @@
|
||||||
"""
|
"""
|
||||||
mae(ŷ, y)
|
mae(ŷ, y)
|
||||||
|
|
||||||
Return the mean of absolute error `sum(abs.(ŷ .- y)) / length(y)`
|
Return the mean of absolute error; calculated as
|
||||||
|
`sum(abs.(ŷ .- y)) / length(y)`.
|
||||||
"""
|
"""
|
||||||
mae(ŷ, y) = sum(abs.(ŷ .- y)) * 1 // length(y)
|
mae(ŷ, y) = sum(abs.(ŷ .- y)) * 1 // length(y)
|
||||||
|
|
||||||
|
@ -10,7 +11,14 @@ mae(ŷ, y) = sum(abs.(ŷ .- y)) * 1 // length(y)
|
||||||
"""
|
"""
|
||||||
mse(ŷ, y)
|
mse(ŷ, y)
|
||||||
|
|
||||||
Return the mean squared error `sum((ŷ .- y).^2) / length(y)`.
|
Return the mean squared error between ŷ and y; calculated as
|
||||||
|
`sum((ŷ .- y).^2) / length(y)`.
|
||||||
|
|
||||||
|
# Examples
|
||||||
|
```jldoctest
|
||||||
|
julia> Flux.mse([0, 2], [1, 1])
|
||||||
|
1//1
|
||||||
|
```
|
||||||
"""
|
"""
|
||||||
mse(ŷ, y) = sum((ŷ .- y).^2) * 1 // length(y)
|
mse(ŷ, y) = sum((ŷ .- y).^2) * 1 // length(y)
|
||||||
|
|
||||||
|
@ -18,10 +26,11 @@ mse(ŷ, y) = sum((ŷ .- y).^2) * 1 // length(y)
|
||||||
"""
|
"""
|
||||||
msle(ŷ, y; ϵ=eps(eltype(ŷ)))
|
msle(ŷ, y; ϵ=eps(eltype(ŷ)))
|
||||||
|
|
||||||
Returns the mean of the squared logarithmic errors `sum((log.(ŷ .+ ϵ) .- log.(y .+ ϵ)).^2) / length(y)`.
|
Return the mean of the squared logarithmic errors; calculated as
|
||||||
|
`sum((log.(ŷ .+ ϵ) .- log.(y .+ ϵ)).^2) / length(y)`.
|
||||||
The `ϵ` term provides numerical stability.
|
The `ϵ` term provides numerical stability.
|
||||||
|
|
||||||
This error penalizes an under-predicted estimate greater than an over-predicted estimate.
|
Penalizes an under-predicted estimate greater than an over-predicted estimate.
|
||||||
"""
|
"""
|
||||||
msle(ŷ, y; ϵ=eps(eltype(ŷ))) = sum((log.(ŷ .+ ϵ) .- log.(y .+ ϵ)).^2) * 1 // length(y)
|
msle(ŷ, y; ϵ=eps(eltype(ŷ))) = sum((log.(ŷ .+ ϵ) .- log.(y .+ ϵ)).^2) * 1 // length(y)
|
||||||
|
|
||||||
|
@ -30,13 +39,12 @@ msle(ŷ, y; ϵ=eps(eltype(ŷ))) = sum((log.(ŷ .+ ϵ) .- log.(y .+ ϵ)).^2) *
|
||||||
"""
|
"""
|
||||||
huber_loss(ŷ, y; δ=1.0)
|
huber_loss(ŷ, y; δ=1.0)
|
||||||
|
|
||||||
Computes the mean of the Huber loss given the prediction `ŷ` and true values `y`. By default, δ is set to 1.0.
|
Return the mean of the [Huber loss](https://en.wikipedia.org/wiki/Huber_loss)
|
||||||
|
given the prediction `ŷ` and true values `y`.
|
||||||
|
|
||||||
| 0.5*|ŷ - y|, for |ŷ - y| <= δ
|
| 0.5 * |ŷ - y|, for |ŷ - y| <= δ
|
||||||
Hubber loss = |
|
Huber loss = |
|
||||||
| δ*(|ŷ - y| - 0.5*δ), otherwise
|
| δ * (|ŷ - y| - 0.5 * δ), otherwise
|
||||||
|
|
||||||
[`Huber Loss`](https://en.wikipedia.org/wiki/Huber_loss).
|
|
||||||
"""
|
"""
|
||||||
function huber_loss(ŷ, y; δ=eltype(ŷ)(1))
|
function huber_loss(ŷ, y; δ=eltype(ŷ)(1))
|
||||||
abs_error = abs.(ŷ .- y)
|
abs_error = abs.(ŷ .- y)
|
||||||
|
@ -58,22 +66,40 @@ function _crossentropy(ŷ::AbstractVecOrMat, y::AbstractVecOrMat, weight::Abstr
|
||||||
end
|
end
|
||||||
|
|
||||||
"""
|
"""
|
||||||
crossentropy(ŷ, y; weight=1)
|
crossentropy(ŷ, y; weight = nothing)
|
||||||
|
|
||||||
Return the crossentropy computed as `-sum(y .* log.(ŷ) .* weight) / size(y, 2)`.
|
Return the cross entropy between the given probability distributions;
|
||||||
|
calculated as `-sum(y .* log.(ŷ) .* weight) / size(y, 2)`.
|
||||||
|
|
||||||
See also [`logitcrossentropy`](@ref), [`binarycrossentropy`](@ref).
|
`weight` can be `Nothing`, a `Number` or an `AbstractVector`.
|
||||||
|
`weight=nothing` acts like `weight=1` but is faster.
|
||||||
|
|
||||||
|
See also: [`Flux.logitcrossentropy`](@ref), [`Flux.binarycrossentropy`](@ref), [`Flux.logitbinarycrossentropy`](@ref)
|
||||||
|
|
||||||
|
# Examples
|
||||||
|
```jldoctest
|
||||||
|
julia> Flux.crossentropy(softmax([-1.1491, 0.8619, 0.3127]), [1, 1, 0])
|
||||||
|
3.085467254747739
|
||||||
|
```
|
||||||
"""
|
"""
|
||||||
crossentropy(ŷ::AbstractVecOrMat, y::AbstractVecOrMat; weight=nothing) = _crossentropy(ŷ, y, weight)
|
crossentropy(ŷ::AbstractVecOrMat, y::AbstractVecOrMat; weight=nothing) = _crossentropy(ŷ, y, weight)
|
||||||
|
|
||||||
"""
|
"""
|
||||||
logitcrossentropy(ŷ, y; weight=1)
|
logitcrossentropy(ŷ, y; weight = 1)
|
||||||
|
|
||||||
Return the crossentropy computed after a [softmax](@ref) operation:
|
Return the crossentropy computed after a [`Flux.logsoftmax`](@ref) operation;
|
||||||
|
calculated as `-sum(y .* logsoftmax(ŷ) .* weight) / size(y, 2)`.
|
||||||
|
|
||||||
-sum(y .* logsoftmax(ŷ) .* weight) / size(y, 2)
|
`logitcrossentropy(ŷ, y)` is mathematically equivalent to
|
||||||
|
[`Flux.crossentropy(softmax(log(ŷ)), y)`](@ref) but it is more numerically stable.
|
||||||
|
|
||||||
See also [`crossentropy`](@ref), [`binarycrossentropy`](@ref).
|
See also: [`Flux.crossentropy`](@ref), [`Flux.binarycrossentropy`](@ref), [`Flux.logitbinarycrossentropy`](@ref)
|
||||||
|
|
||||||
|
# Examples
|
||||||
|
```jldoctest
|
||||||
|
julia> Flux.logitcrossentropy([-1.1491, 0.8619, 0.3127], [1, 1, 0])
|
||||||
|
3.085467254747738
|
||||||
|
```
|
||||||
"""
|
"""
|
||||||
function logitcrossentropy(ŷ::AbstractVecOrMat, y::AbstractVecOrMat; weight = 1)
|
function logitcrossentropy(ŷ::AbstractVecOrMat, y::AbstractVecOrMat; weight = 1)
|
||||||
return -sum(y .* logsoftmax(ŷ) .* weight) * 1 // size(y, 2)
|
return -sum(y .* logsoftmax(ŷ) .* weight) * 1 // size(y, 2)
|
||||||
|
@ -82,9 +108,20 @@ end
|
||||||
"""
|
"""
|
||||||
binarycrossentropy(ŷ, y; ϵ=eps(ŷ))
|
binarycrossentropy(ŷ, y; ϵ=eps(ŷ))
|
||||||
|
|
||||||
Return `-y*log(ŷ + ϵ) - (1-y)*log(1-ŷ + ϵ)`. The ϵ term provides numerical stability.
|
Return ``-y*\\log(ŷ + ϵ) - (1-y)*\\log(1-ŷ + ϵ)``. The `ϵ` term provides numerical stability.
|
||||||
|
|
||||||
Typically, the prediction `ŷ` is given by the output of a [`sigmoid`](@ref) activation.
|
Typically, the prediction `ŷ` is given by the output of a [`sigmoid`](@ref) activation.
|
||||||
|
|
||||||
|
See also: [`Flux.crossentropy`](@ref), [`Flux.logitcrossentropy`](@ref), [`Flux.logitbinarycrossentropy`](@ref)
|
||||||
|
|
||||||
|
# Examples
|
||||||
|
```jldoctest
|
||||||
|
julia> Flux.binarycrossentropy.(σ.([-1.1491, 0.8619, 0.3127]), [1, 1, 0])
|
||||||
|
3-element Array{Float64,1}:
|
||||||
|
1.424397097347566
|
||||||
|
0.35231664672364077
|
||||||
|
0.8616703662235441
|
||||||
|
```
|
||||||
"""
|
"""
|
||||||
binarycrossentropy(ŷ, y; ϵ=eps(ŷ)) = -y*log(ŷ + ϵ) - (1 - y)*log(1 - ŷ + ϵ)
|
binarycrossentropy(ŷ, y; ϵ=eps(ŷ)) = -y*log(ŷ + ϵ) - (1 - y)*log(1 - ŷ + ϵ)
|
||||||
|
|
||||||
|
@ -94,10 +131,19 @@ CuArrays.@cufunc binarycrossentropy(ŷ, y; ϵ=eps(ŷ)) = -y*log(ŷ + ϵ) - (1
|
||||||
"""
|
"""
|
||||||
logitbinarycrossentropy(ŷ, y)
|
logitbinarycrossentropy(ŷ, y)
|
||||||
|
|
||||||
`logitbinarycrossentropy(ŷ, y)` is mathematically equivalent to `binarycrossentropy(σ(ŷ), y)`
|
`logitbinarycrossentropy(ŷ, y)` is mathematically equivalent to
|
||||||
but it is more numerically stable.
|
[`Flux.binarycrossentropy(σ(log(ŷ)), y)`](@ref) but it is more numerically stable.
|
||||||
|
|
||||||
See also [`binarycrossentropy`](@ref), [`sigmoid`](@ref), [`logsigmoid`](@ref).
|
See also: [`Flux.crossentropy`](@ref), [`Flux.logitcrossentropy`](@ref), [`Flux.binarycrossentropy`](@ref)
|
||||||
|
|
||||||
|
# Examples
|
||||||
|
```jldoctest
|
||||||
|
julia> Flux.logitbinarycrossentropy.([-1.1491, 0.8619, 0.3127], [1, 1, 0])
|
||||||
|
3-element Array{Float64,1}:
|
||||||
|
1.4243970973475661
|
||||||
|
0.35231664672364094
|
||||||
|
0.8616703662235443
|
||||||
|
```
|
||||||
"""
|
"""
|
||||||
logitbinarycrossentropy(ŷ, y) = (1 - y)*ŷ - logσ(ŷ)
|
logitbinarycrossentropy(ŷ, y) = (1 - y)*ŷ - logσ(ŷ)
|
||||||
|
|
||||||
|
@ -107,26 +153,27 @@ CuArrays.@cufunc logitbinarycrossentropy(ŷ, y) = (1 - y)*ŷ - logσ(ŷ)
|
||||||
"""
|
"""
|
||||||
normalise(x; dims=1)
|
normalise(x; dims=1)
|
||||||
|
|
||||||
Normalises `x` to mean 0 and standard deviation 1, across the dimensions given by `dims`. Defaults to normalising over columns.
|
Normalise `x` to mean 0 and standard deviation 1 across the dimensions given by `dims`.
|
||||||
|
Defaults to normalising over columns.
|
||||||
|
|
||||||
```julia-repl
|
```jldoctest
|
||||||
julia> a = reshape(collect(1:9), 3, 3)
|
julia> a = reshape(collect(1:9), 3, 3)
|
||||||
3×3 Array{Int64,2}:
|
3×3 Array{Int64,2}:
|
||||||
1 4 7
|
1 4 7
|
||||||
2 5 8
|
2 5 8
|
||||||
3 6 9
|
3 6 9
|
||||||
|
|
||||||
julia> normalise(a)
|
julia> Flux.normalise(a)
|
||||||
3×3 Array{Float64,2}:
|
3×3 Array{Float64,2}:
|
||||||
-1.22474 -1.22474 -1.22474
|
-1.22474 -1.22474 -1.22474
|
||||||
0.0 0.0 0.0
|
0.0 0.0 0.0
|
||||||
1.22474 1.22474 1.22474
|
1.22474 1.22474 1.22474
|
||||||
|
|
||||||
julia> normalise(a, dims=2)
|
julia> Flux.normalise(a, dims=2)
|
||||||
3×3 Array{Float64,2}:
|
3×3 Array{Float64,2}:
|
||||||
-1.22474 0.0 1.22474
|
-1.22474 0.0 1.22474
|
||||||
-1.22474 0.0 1.22474
|
-1.22474 0.0 1.22474
|
||||||
-1.22474 0.0 1.22474
|
-1.22474 0.0 1.22474
|
||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
function normalise(x::AbstractArray; dims=1)
|
function normalise(x::AbstractArray; dims=1)
|
||||||
|
@ -138,10 +185,14 @@ end
|
||||||
"""
|
"""
|
||||||
kldivergence(ŷ, y)
|
kldivergence(ŷ, y)
|
||||||
|
|
||||||
KLDivergence is a measure of how much one probability distribution is different from the other.
|
Return the
|
||||||
It is always non-negative and zero only when both the distributions are equal everywhere.
|
[Kullback-Leibler divergence](https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence)
|
||||||
|
between the given probability distributions.
|
||||||
|
|
||||||
[KL Divergence](https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence).
|
KL divergence is a measure of how much one probability distribution is different
|
||||||
|
from the other.
|
||||||
|
It is always non-negative and zero only when both the distributions are equal
|
||||||
|
everywhere.
|
||||||
"""
|
"""
|
||||||
function kldivergence(ŷ, y)
|
function kldivergence(ŷ, y)
|
||||||
entropy = sum(y .* log.(y)) * 1 //size(y,2)
|
entropy = sum(y .* log.(y)) * 1 //size(y,2)
|
||||||
|
@ -152,59 +203,60 @@ end
|
||||||
"""
|
"""
|
||||||
poisson(ŷ, y)
|
poisson(ŷ, y)
|
||||||
|
|
||||||
Poisson loss function is a measure of how the predicted distribution diverges from the expected distribution.
|
Return how much the predicted distribution `ŷ` diverges from the expected Poisson
|
||||||
Returns `sum(ŷ .- y .* log.(ŷ)) / size(y, 2)`
|
distribution `y`; calculated as `sum(ŷ .- y .* log.(ŷ)) / size(y, 2)`.
|
||||||
|
|
||||||
[Poisson Loss](https://peltarion.com/knowledge-center/documentation/modeling-view/build-an-ai-model/loss-functions/poisson).
|
[More information.](https://peltarion.com/knowledge-center/documentation/modeling-view/build-an-ai-model/loss-functions/poisson).
|
||||||
"""
|
"""
|
||||||
poisson(ŷ, y) = sum(ŷ .- y .* log.(ŷ)) * 1 // size(y,2)
|
poisson(ŷ, y) = sum(ŷ .- y .* log.(ŷ)) * 1 // size(y,2)
|
||||||
|
|
||||||
"""
|
"""
|
||||||
hinge(ŷ, y)
|
hinge(ŷ, y)
|
||||||
|
|
||||||
Measures the loss given the prediction `ŷ` and true labels `y` (containing 1 or -1).
|
Return the [hinge loss](https://en.wikipedia.org/wiki/Hinge_loss) given the
|
||||||
Returns `sum((max.(0, 1 .- ŷ .* y))) / size(y, 2)`
|
prediction `ŷ` and true labels `y` (containing 1 or -1); calculated as
|
||||||
|
`sum(max.(0, 1 .- ŷ .* y)) / size(y, 2)`.
|
||||||
|
|
||||||
[Hinge Loss](https://en.wikipedia.org/wiki/Hinge_loss)
|
See also: [`squared_hinge`](@ref)
|
||||||
See also [`squared_hinge`](@ref).
|
|
||||||
"""
|
"""
|
||||||
hinge(ŷ, y) = sum(max.(0, 1 .- ŷ .* y)) * 1 // size(y, 2)
|
hinge(ŷ, y) = sum(max.(0, 1 .- ŷ .* y)) * 1 // size(y, 2)
|
||||||
|
|
||||||
"""
|
"""
|
||||||
squared_hinge(ŷ, y)
|
squared_hinge(ŷ, y)
|
||||||
|
|
||||||
Computes squared hinge loss given the prediction `ŷ` and true labels `y` (conatining 1 or -1).
|
Return the squared hinge loss given the prediction `ŷ` and true labels `y`
|
||||||
Returns `sum((max.(0, 1 .- ŷ .* y)).^2) / size(y, 2)`
|
(containing 1 or -1); calculated as `sum((max.(0, 1 .- ŷ .* y)).^2) / size(y, 2)`.
|
||||||
|
|
||||||
See also [`hinge`](@ref).
|
See also: [`hinge`](@ref)
|
||||||
"""
|
"""
|
||||||
squared_hinge(ŷ, y) = sum((max.(0, 1 .- ŷ .* y)).^2) * 1 // size(y, 2)
|
squared_hinge(ŷ, y) = sum((max.(0, 1 .- ŷ .* y)).^2) * 1 // size(y, 2)
|
||||||
|
|
||||||
"""
|
"""
|
||||||
dice_coeff_loss(ŷ, y; smooth=1)
|
dice_coeff_loss(ŷ, y; smooth=1)
|
||||||
|
|
||||||
Loss function used in Image Segmentation. Calculates loss based on dice coefficient. Similar to F1_score.
|
Return a loss based on the dice coefficient.
|
||||||
Returns `1 - 2*sum(|ŷ .* y| + smooth) / (sum(ŷ.^2) + sum(y.^2) + smooth)`
|
Used in the [V-Net](https://arxiv.org/pdf/1606.04797v1.pdf) image segmentation
|
||||||
|
architecture.
|
||||||
[V-Net: Fully Convolutional Neural Networks forVolumetric Medical Image Segmentation](https://arxiv.org/pdf/1606.04797v1.pdf)
|
Similar to the F1_score. Calculated as:
|
||||||
|
1 - 2*sum(|ŷ .* y| + smooth) / (sum(ŷ.^2) + sum(y.^2) + smooth)`
|
||||||
"""
|
"""
|
||||||
dice_coeff_loss(ŷ, y; smooth=eltype(ŷ)(1.0)) = 1 - (2*sum(y .* ŷ) + smooth) / (sum(y.^2) + sum(ŷ.^2) + smooth)
|
dice_coeff_loss(ŷ, y; smooth=eltype(ŷ)(1.0)) = 1 - (2*sum(y .* ŷ) + smooth) / (sum(y.^2) + sum(ŷ.^2) + smooth)
|
||||||
|
|
||||||
"""
|
"""
|
||||||
tversky_loss(ŷ, y; β=0.7)
|
tversky_loss(ŷ, y; β=0.7)
|
||||||
|
|
||||||
Used with imbalanced data to give more weightage to False negatives.
|
Return the [Tversky loss](https://arxiv.org/pdf/1706.05721.pdf).
|
||||||
|
Used with imbalanced data to give more weight to false negatives.
|
||||||
Larger β weigh recall higher than precision (by placing more emphasis on false negatives)
|
Larger β weigh recall higher than precision (by placing more emphasis on false negatives)
|
||||||
Returns `1 - sum(|y .* ŷ| + 1) / (sum(y .* ŷ + β*(1 .- y) .* ŷ + (1 - β)*y .* (1 .- ŷ)) + 1)`
|
Calculated as:
|
||||||
|
1 - sum(|y .* ŷ| + 1) / (sum(y .* ŷ + β*(1 .- y) .* ŷ + (1 - β)*y .* (1 .- ŷ)) + 1)
|
||||||
[Tversky loss function for image segmentation using 3D fully convolutional deep networks](https://arxiv.org/pdf/1706.05721.pdf)
|
|
||||||
"""
|
"""
|
||||||
tversky_loss(ŷ, y; β=eltype(ŷ)(0.7)) = 1 - (sum(y .* ŷ) + 1) / (sum(y .* ŷ + β*(1 .- y) .* ŷ + (1 - β)*y .* (1 .- ŷ)) + 1)
|
tversky_loss(ŷ, y; β=eltype(ŷ)(0.7)) = 1 - (sum(y .* ŷ) + 1) / (sum(y .* ŷ + β*(1 .- y) .* ŷ + (1 - β)*y .* (1 .- ŷ)) + 1)
|
||||||
|
|
||||||
"""
|
"""
|
||||||
flatten(x::AbstractArray)
|
flatten(x::AbstractArray)
|
||||||
|
|
||||||
Transforms (w,h,c,b)-shaped input into (w x h x c,b)-shaped output,
|
Transform (w, h, c, b)-shaped input into (w × h × c, b)-shaped output
|
||||||
by linearizing all values for each element in the batch.
|
by linearizing all values for each element in the batch.
|
||||||
"""
|
"""
|
||||||
function flatten(x::AbstractArray)
|
function flatten(x::AbstractArray)
|
||||||
|
|
|
@ -45,22 +45,20 @@ cudaconvert(x::OneHotMatrix{<:CuArray}) = OneHotMatrix(x.height, cudaconvert(x.d
|
||||||
"""
|
"""
|
||||||
onehot(l, labels[, unk])
|
onehot(l, labels[, unk])
|
||||||
|
|
||||||
Create an [`OneHotVector`](@ref) wtih `l`-th element be `true` based on possible `labels` set.
|
Create a `OneHotVector` with its `l`-th element `true` based on the
|
||||||
If `unk` is given, it retruns `onehot(unk, labels)` if the input label `l` is not find in `labels`; otherwise
|
possible set of `labels`.
|
||||||
it will error.
|
If `unk` is given, return `onehot(unk, labels)` if the input label `l` is not found
|
||||||
|
in `labels`; otherwise it will error.
|
||||||
## Examples
|
|
||||||
|
|
||||||
|
# Examples
|
||||||
```jldoctest
|
```jldoctest
|
||||||
julia> using Flux: onehot
|
julia> Flux.onehot(:b, [:a, :b, :c])
|
||||||
|
|
||||||
julia> onehot(:b, [:a, :b, :c])
|
|
||||||
3-element Flux.OneHotVector:
|
3-element Flux.OneHotVector:
|
||||||
0
|
0
|
||||||
1
|
1
|
||||||
0
|
0
|
||||||
|
|
||||||
julia> onehot(:c, [:a, :b, :c])
|
julia> Flux.onehot(:c, [:a, :b, :c])
|
||||||
3-element Flux.OneHotVector:
|
3-element Flux.OneHotVector:
|
||||||
0
|
0
|
||||||
0
|
0
|
||||||
|
@ -82,15 +80,14 @@ end
|
||||||
"""
|
"""
|
||||||
onehotbatch(ls, labels[, unk...])
|
onehotbatch(ls, labels[, unk...])
|
||||||
|
|
||||||
Create an [`OneHotMatrix`](@ref) with a batch of labels based on possible `labels` set, returns the
|
Create a `OneHotMatrix` with a batch of labels based on the
|
||||||
`onehot(unk, labels)` if given labels `ls` is not found in set `labels`.
|
possible set of `labels`.
|
||||||
|
If `unk` is given, return [`onehot(unk, labels)`](@ref) if one of the input
|
||||||
## Examples
|
labels `ls` is not found in `labels`; otherwise it will error.
|
||||||
|
|
||||||
|
# Examples
|
||||||
```jldoctest
|
```jldoctest
|
||||||
julia> using Flux: onehotbatch
|
julia> Flux.onehotbatch([:b, :a, :b], [:a, :b, :c])
|
||||||
|
|
||||||
julia> onehotbatch([:b, :a, :b], [:a, :b, :c])
|
|
||||||
3×3 Flux.OneHotMatrix{Array{Flux.OneHotVector,1}}:
|
3×3 Flux.OneHotMatrix{Array{Flux.OneHotVector,1}}:
|
||||||
0 1 0
|
0 1 0
|
||||||
1 0 1
|
1 0 1
|
||||||
|
@ -107,13 +104,12 @@ Base.argmax(xs::OneHotVector) = xs.ix
|
||||||
|
|
||||||
Inverse operations of [`onehot`](@ref).
|
Inverse operations of [`onehot`](@ref).
|
||||||
|
|
||||||
|
# Examples
|
||||||
```jldoctest
|
```jldoctest
|
||||||
julia> using Flux: onecold
|
julia> Flux.onecold([true, false, false], [:a, :b, :c])
|
||||||
|
|
||||||
julia> onecold([true, false, false], [:a, :b, :c])
|
|
||||||
:a
|
:a
|
||||||
|
|
||||||
julia> onecold([0.3, 0.2, 0.5], [:a, :b, :c])
|
julia> Flux.onecold([0.3, 0.2, 0.5], [:a, :b, :c])
|
||||||
:c
|
:c
|
||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -6,24 +6,25 @@ const ϵ = 1e-8
|
||||||
# TODO: should use weak refs
|
# TODO: should use weak refs
|
||||||
|
|
||||||
"""
|
"""
|
||||||
Descent(η)
|
Descent(η = 0.1)
|
||||||
|
|
||||||
Classic gradient descent optimiser with learning rate `η`.
|
Classic gradient descent optimiser with learning rate `η`.
|
||||||
For each parameter `p` and its gradient `δp`, this runs `p -= η*δp`
|
For each parameter `p` and its gradient `δp`, this runs `p -= η*δp`
|
||||||
|
|
||||||
## Parameters
|
# Parameters
|
||||||
- Learning Rate (η): The amount by which the gradients are discounted before updating the weights. Defaults to `0.1`.
|
- Learning rate (`η`): Amount by which gradients are discounted before updating
|
||||||
|
the weights.
|
||||||
|
|
||||||
## Example
|
# Examples
|
||||||
```julia-repl
|
```julia
|
||||||
opt = Descent() # uses default η (0.1)
|
opt = Descent()
|
||||||
|
|
||||||
opt = Descent(0.3) # use provided η
|
opt = Descent(0.3)
|
||||||
|
|
||||||
ps = params(model)
|
ps = params(model)
|
||||||
|
|
||||||
gs = gradient(ps) do
|
gs = gradient(ps) do
|
||||||
loss(x, y)
|
loss(x, y)
|
||||||
end
|
end
|
||||||
|
|
||||||
Flux.Optimise.update!(opt, ps, gs)
|
Flux.Optimise.update!(opt, ps, gs)
|
||||||
|
@ -40,17 +41,19 @@ function apply!(o::Descent, x, Δ)
|
||||||
end
|
end
|
||||||
|
|
||||||
"""
|
"""
|
||||||
Momentum(η, ρ)
|
Momentum(η = 0.01, ρ = 0.9)
|
||||||
|
|
||||||
Gradient descent with learning rate `η` and momentum `ρ`.
|
Gradient descent optimizer with learning rate `η` and momentum `ρ`.
|
||||||
|
|
||||||
## Parameters
|
# Parameters
|
||||||
- Learning Rate (`η`): Amount by which gradients are discounted before updating the weights. Defaults to `0.01`.
|
- Learning rate (`η`): Amount by which gradients are discounted before updating
|
||||||
- Momentum (`ρ`): Parameter that accelerates descent in the relevant direction and dampens oscillations. Defaults to `0.9`.
|
the weights.
|
||||||
|
- Momentum (`ρ`): Controls the acceleration of gradient descent in the
|
||||||
|
prominent direction, in effect dampening oscillations.
|
||||||
|
|
||||||
## Examples
|
# Examples
|
||||||
```julia
|
```julia
|
||||||
opt = Momentum() # uses defaults of η = 0.01 and ρ = 0.9
|
opt = Momentum()
|
||||||
|
|
||||||
opt = Momentum(0.01, 0.99)
|
opt = Momentum(0.01, 0.99)
|
||||||
```
|
```
|
||||||
|
@ -71,17 +74,19 @@ function apply!(o::Momentum, x, Δ)
|
||||||
end
|
end
|
||||||
|
|
||||||
"""
|
"""
|
||||||
Nesterov(η, ρ)
|
Nesterov(η = 0.001, ρ = 0.9)
|
||||||
|
|
||||||
Gradient descent with learning rate `η` and Nesterov momentum `ρ`.
|
Gradient descent optimizer with learning rate `η` and Nesterov momentum `ρ`.
|
||||||
|
|
||||||
## Parameters
|
# Parameters
|
||||||
- Learning Rate (η): Amount by which the gradients are dicsounted berfore updating the weights. Defaults to `0.001`.
|
- Learning rate (`η`): Amount by which gradients are discounted before updating
|
||||||
- Nesterov Momentum (ρ): Parameters controlling the amount of nesterov momentum to be applied. Defaults to `0.9`.
|
the weights.
|
||||||
|
- Nesterov momentum (`ρ`): Controls the acceleration of gradient descent in the
|
||||||
|
prominent direction, in effect dampening oscillations.
|
||||||
|
|
||||||
## Examples
|
# Examples
|
||||||
```julia
|
```julia
|
||||||
opt = Nesterov() # uses defaults η = 0.001 and ρ = 0.9
|
opt = Nesterov()
|
||||||
|
|
||||||
opt = Nesterov(0.003, 0.95)
|
opt = Nesterov(0.003, 0.95)
|
||||||
```
|
```
|
||||||
|
@ -103,23 +108,25 @@ function apply!(o::Nesterov, x, Δ)
|
||||||
end
|
end
|
||||||
|
|
||||||
"""
|
"""
|
||||||
RMSProp(η, ρ)
|
RMSProp(η = 0.001, ρ = 0.9)
|
||||||
|
|
||||||
Implements the RMSProp algortihm. Often a good choice for recurrent networks. Parameters other than learning rate generally don't need tuning.
|
Optimizer using the
|
||||||
|
[RMSProp](https://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf)
|
||||||
|
algorithm. Often a good choice for recurrent networks. Parameters other than learning rate
|
||||||
|
generally don't need tuning.
|
||||||
|
|
||||||
## Parameters
|
# Parameters
|
||||||
- Learning Rate (η): Defaults to `0.001`.
|
- Learning rate (`η`): Amount by which gradients are discounted before updating
|
||||||
- Rho (ρ): Defaults to `0.9`.
|
the weights.
|
||||||
|
- Momentum (`ρ`): Controls the acceleration of gradient descent in the
|
||||||
|
prominent direction, in effect dampening oscillations.
|
||||||
|
|
||||||
## Examples
|
# Examples
|
||||||
```julia
|
```julia
|
||||||
opt = RMSProp() # uses default η = 0.001 and ρ = 0.9
|
opt = RMSProp()
|
||||||
|
|
||||||
opt = RMSProp(0.002, 0.95)
|
opt = RMSProp(0.002, 0.95)
|
||||||
```
|
```
|
||||||
|
|
||||||
## References
|
|
||||||
[RMSProp](https://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf)
|
|
||||||
"""
|
"""
|
||||||
mutable struct RMSProp
|
mutable struct RMSProp
|
||||||
eta::Float64
|
eta::Float64
|
||||||
|
@ -137,23 +144,22 @@ function apply!(o::RMSProp, x, Δ)
|
||||||
end
|
end
|
||||||
|
|
||||||
"""
|
"""
|
||||||
ADAM(η, β::Tuple)
|
ADAM(η = 0.001, β::Tuple = (0.9, 0.999))
|
||||||
|
|
||||||
Implements the ADAM optimiser.
|
[ADAM](https://arxiv.org/abs/1412.6980v8) optimiser.
|
||||||
|
|
||||||
## Paramters
|
# Parameters
|
||||||
- Learning Rate (`η`): Defaults to `0.001`.
|
- Learning rate (`η`): Amount by which gradients are discounted before updating
|
||||||
- Beta (`β::Tuple`): The first element refers to β1 and the second to β2. Defaults to `(0.9, 0.999)`.
|
the weights.
|
||||||
|
- Decay of momentums (`β::Tuple`): Exponential decay for the first (β1) and the
|
||||||
## Examples
|
second (β2) momentum estimate.
|
||||||
|
|
||||||
|
# Examples
|
||||||
```julia
|
```julia
|
||||||
opt = ADAM() # uses the default η = 0.001 and β = (0.9, 0.999)
|
opt = ADAM()
|
||||||
|
|
||||||
opt = ADAM(0.001, (0.9, 0.8))
|
opt = ADAM(0.001, (0.9, 0.8))
|
||||||
```
|
```
|
||||||
## References
|
|
||||||
[ADAM](https://arxiv.org/abs/1412.6980v8) optimiser.
|
|
||||||
"""
|
"""
|
||||||
mutable struct ADAM
|
mutable struct ADAM
|
||||||
eta::Float64
|
eta::Float64
|
||||||
|
@ -174,24 +180,22 @@ function apply!(o::ADAM, x, Δ)
|
||||||
end
|
end
|
||||||
|
|
||||||
"""
|
"""
|
||||||
RADAM(η, β::Tuple)
|
RADAM(η = 0.001, β::Tuple = (0.9, 0.999))
|
||||||
|
|
||||||
Implements the rectified ADAM optimizer.
|
[Rectified ADAM](https://arxiv.org/pdf/1908.03265v1.pdf) optimizer.
|
||||||
|
|
||||||
## Parameters
|
# Parameters
|
||||||
- Learning Rate (η): Defaults to `0.001`
|
- Learning rate (`η`): Amount by which gradients are discounted before updating
|
||||||
- Beta (β::Tuple): The first element refers to β1 and the second to β2. Defaults to `(0.9, 0.999)`.
|
the weights.
|
||||||
|
- Decay of momentums (`β::Tuple`): Exponential decay for the first (β1) and the
|
||||||
## Examples
|
second (β2) momentum estimate.
|
||||||
|
|
||||||
|
# Examples
|
||||||
```julia
|
```julia
|
||||||
opt = RADAM() # uses the default η = 0.001 and β = (0.9, 0.999)
|
opt = RADAM()
|
||||||
|
|
||||||
opt = RADAM(0.001, (0.9, 0.8))
|
opt = RADAM(0.001, (0.9, 0.8))
|
||||||
```
|
```
|
||||||
|
|
||||||
## References
|
|
||||||
[RADAM](https://arxiv.org/pdf/1908.03265v1.pdf) optimiser (Rectified ADAM).
|
|
||||||
"""
|
"""
|
||||||
mutable struct RADAM
|
mutable struct RADAM
|
||||||
eta::Float64
|
eta::Float64
|
||||||
|
@ -219,22 +223,22 @@ function apply!(o::RADAM, x, Δ)
|
||||||
end
|
end
|
||||||
|
|
||||||
"""
|
"""
|
||||||
AdaMax(η, β::Tuple)
|
AdaMax(η = 0.001, β::Tuple = (0.9, 0.999))
|
||||||
|
|
||||||
Variant of ADAM based on ∞-norm.
|
[AdaMax](https://arxiv.org/abs/1412.6980v9) is a variant of ADAM based on the ∞-norm.
|
||||||
|
|
||||||
## Parameters
|
# Parameters
|
||||||
- Learning Rate (η): Defaults to `0.001`
|
- Learning rate (`η`): Amount by which gradients are discounted before updating
|
||||||
- Beta (β::Tuple): The first element refers to β1 and the second to β2. Defaults to `(0.9, 0.999)`.
|
the weights.
|
||||||
|
- Decay of momentums (`β::Tuple`): Exponential decay for the first (β1) and the
|
||||||
|
second (β2) momentum estimate.
|
||||||
|
|
||||||
## Examples
|
# Examples
|
||||||
```julia
|
```julia
|
||||||
opt = AdaMax() # uses default η and β
|
opt = AdaMax()
|
||||||
|
|
||||||
opt = AdaMax(0.001, (0.9, 0.995))
|
opt = AdaMax(0.001, (0.9, 0.995))
|
||||||
```
|
```
|
||||||
## References
|
|
||||||
[AdaMax](https://arxiv.org/abs/1412.6980v9) optimiser.
|
|
||||||
"""
|
"""
|
||||||
mutable struct AdaMax
|
mutable struct AdaMax
|
||||||
eta::Float64
|
eta::Float64
|
||||||
|
@ -255,23 +259,22 @@ function apply!(o::AdaMax, x, Δ)
|
||||||
end
|
end
|
||||||
|
|
||||||
"""
|
"""
|
||||||
ADAGrad(η)
|
ADAGrad(η = 0.1)
|
||||||
|
|
||||||
Implements AdaGrad. It has parameter specific learning rates based on how frequently it is updated.
|
[ADAGrad](http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf) optimizer. It has
|
||||||
|
parameter specific learning rates based on how frequently it is updated.
|
||||||
|
Parameters don't need tuning.
|
||||||
|
|
||||||
## Parameters
|
# Parameters
|
||||||
- Learning Rate (η): Defaults to `0.1`
|
- Learning rate (`η`): Amount by which gradients are discounted before updating
|
||||||
|
the weights.
|
||||||
|
|
||||||
## Examples
|
# Examples
|
||||||
```julia
|
```julia
|
||||||
opt = ADAGrad() # uses default η = 0.1
|
opt = ADAGrad()
|
||||||
|
|
||||||
opt = ADAGrad(0.001)
|
opt = ADAGrad(0.001)
|
||||||
```
|
```
|
||||||
|
|
||||||
## References
|
|
||||||
[ADAGrad](http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf) optimiser.
|
|
||||||
Parameters don't need tuning.
|
|
||||||
"""
|
"""
|
||||||
mutable struct ADAGrad
|
mutable struct ADAGrad
|
||||||
eta::Float64
|
eta::Float64
|
||||||
|
@ -288,21 +291,21 @@ function apply!(o::ADAGrad, x, Δ)
|
||||||
end
|
end
|
||||||
|
|
||||||
"""
|
"""
|
||||||
ADADelta(ρ)
|
ADADelta(ρ = 0.9)
|
||||||
|
|
||||||
Version of ADAGrad that adapts learning rate based on a window of past gradient updates. Parameters don't need tuning.
|
[ADADelta](https://arxiv.org/abs/1212.5701) is a version of ADAGrad adapting its learning
|
||||||
|
rate based on a window of past gradient updates.
|
||||||
|
Parameters don't need tuning.
|
||||||
|
|
||||||
## Parameters
|
# Parameters
|
||||||
- Rho (ρ): Factor by which gradient is decayed at each time step. Defaults to `0.9`.
|
- Rho (`ρ`): Factor by which the gradient is decayed at each time step.
|
||||||
|
|
||||||
## Examples
|
# Examples
|
||||||
```julia
|
```julia
|
||||||
opt = ADADelta() # uses default ρ = 0.9
|
opt = ADADelta()
|
||||||
|
|
||||||
opt = ADADelta(0.89)
|
opt = ADADelta(0.89)
|
||||||
```
|
```
|
||||||
|
|
||||||
## References
|
|
||||||
[ADADelta](https://arxiv.org/abs/1212.5701) optimiser.
|
|
||||||
"""
|
"""
|
||||||
mutable struct ADADelta
|
mutable struct ADADelta
|
||||||
rho::Float64
|
rho::Float64
|
||||||
|
@ -321,22 +324,23 @@ function apply!(o::ADADelta, x, Δ)
|
||||||
end
|
end
|
||||||
|
|
||||||
"""
|
"""
|
||||||
AMSGrad(η, β::Tuple)
|
AMSGrad(η = 0.001, β::Tuple = (0.9, 0.999))
|
||||||
|
|
||||||
Implements AMSGrad version of the ADAM optimiser. Parameters don't need tuning.
|
The [AMSGrad](https://openreview.net/forum?id=ryQu7f-RZ) version of the ADAM
|
||||||
|
optimiser. Parameters don't need tuning.
|
||||||
|
|
||||||
## Parameters
|
# Parameters
|
||||||
- Learning Rate (η): Defaults to `0.001`.
|
- Learning rate (`η`): Amount by which gradients are discounted before updating
|
||||||
- Beta (β::Tuple): The first element refers to β1 and the second to β2. Defaults to `(0.9, 0.999)`.
|
the weights.
|
||||||
|
- Decay of momentums (`β::Tuple`): Exponential decay for the first (β1) and the
|
||||||
|
second (β2) momentum estimate.
|
||||||
|
|
||||||
## Examples
|
# Examples
|
||||||
```julia
|
```julia
|
||||||
opt = AMSGrad() # uses default η and β
|
opt = AMSGrad()
|
||||||
|
|
||||||
opt = AMSGrad(0.001, (0.89, 0.995))
|
opt = AMSGrad(0.001, (0.89, 0.995))
|
||||||
```
|
```
|
||||||
|
|
||||||
## References
|
|
||||||
[AMSGrad](https://openreview.net/forum?id=ryQu7f-RZ) optimiser.
|
|
||||||
"""
|
"""
|
||||||
mutable struct AMSGrad
|
mutable struct AMSGrad
|
||||||
eta::Float64
|
eta::Float64
|
||||||
|
@ -356,22 +360,23 @@ function apply!(o::AMSGrad, x, Δ)
|
||||||
end
|
end
|
||||||
|
|
||||||
"""
|
"""
|
||||||
NADAM(η, β::Tuple)
|
NADAM(η = 0.001, β::Tuple = (0.9, 0.999))
|
||||||
|
|
||||||
Nesterov variant of ADAM. Parameters don't need tuning.
|
[NADAM](http://cs229.stanford.edu/proj2015/054_report.pdf) is a Nesterov variant of ADAM.
|
||||||
|
Parameters don't need tuning.
|
||||||
|
|
||||||
## Parameters
|
# Parameters
|
||||||
- Learning Rate (η): Defaults to `0.001`.
|
- Learning rate (`η`): Amount by which gradients are discounted before updating
|
||||||
- Beta (β::Tuple): The first element refers to β1 and the second to β2. Defaults to `(0.9, 0.999)`.
|
the weights.
|
||||||
|
- Decay of momentums (`β::Tuple`): Exponential decay for the first (β1) and the
|
||||||
|
second (β2) momentum estimate.
|
||||||
|
|
||||||
## Examples
|
# Examples
|
||||||
```julia
|
```julia
|
||||||
opt = NADAM() # uses default η and β
|
opt = NADAM()
|
||||||
|
|
||||||
opt = NADAM(0.002, (0.89, 0.995))
|
opt = NADAM(0.002, (0.89, 0.995))
|
||||||
```
|
```
|
||||||
|
|
||||||
## References
|
|
||||||
[NADAM](http://cs229.stanford.edu/proj2015/054_report.pdf) optimiser.
|
|
||||||
"""
|
"""
|
||||||
mutable struct NADAM
|
mutable struct NADAM
|
||||||
eta::Float64
|
eta::Float64
|
||||||
|
@ -392,23 +397,24 @@ function apply!(o::NADAM, x, Δ)
|
||||||
end
|
end
|
||||||
|
|
||||||
"""
|
"""
|
||||||
ADAMW(η, β::Tuple, decay)
|
ADAMW(η = 0.001, β::Tuple = (0.9, 0.999), decay = 0)
|
||||||
|
|
||||||
Variant of ADAM defined by fixing weight decay regularization.
|
[ADAMW](https://arxiv.org/abs/1711.05101) is a variant of ADAM fixing (as in repairing) its
|
||||||
|
weight decay regularization.
|
||||||
|
|
||||||
## Parameters
|
# Parameters
|
||||||
- Learning Rate (η): Defaults to `0.001`.
|
- Learning rate (`η`): Amount by which gradients are discounted before updating
|
||||||
- Beta (β::Tuple): The first element refers to β1 and the second to β2. Defaults to (0.9, 0.999).
|
the weights.
|
||||||
- decay: Decay applied to weights during optimisation. Defaults to 0.
|
- Decay of momentums (`β::Tuple`): Exponential decay for the first (β1) and the
|
||||||
|
second (β2) momentum estimate.
|
||||||
|
- `decay`: Decay applied to weights during optimisation.
|
||||||
|
|
||||||
## Examples
|
# Examples
|
||||||
```julia
|
```julia
|
||||||
opt = ADAMW() # uses default η, β and decay
|
opt = ADAMW()
|
||||||
|
|
||||||
opt = ADAMW(0.001, (0.89, 0.995), 0.1)
|
opt = ADAMW(0.001, (0.89, 0.995), 0.1)
|
||||||
```
|
```
|
||||||
|
|
||||||
## References
|
|
||||||
[ADAMW](https://arxiv.org/abs/1711.05101)
|
|
||||||
"""
|
"""
|
||||||
ADAMW(η = 0.001, β = (0.9, 0.999), decay = 0) =
|
ADAMW(η = 0.001, β = (0.9, 0.999), decay = 0) =
|
||||||
Optimiser(ADAM(η, β), WeightDecay(decay))
|
Optimiser(ADAM(η, β), WeightDecay(decay))
|
||||||
|
@ -441,14 +447,13 @@ function apply!(o::Optimiser, x, Δ)
|
||||||
end
|
end
|
||||||
|
|
||||||
"""
|
"""
|
||||||
InvDecay(γ)
|
InvDecay(γ = 0.001)
|
||||||
|
|
||||||
Applies inverse time decay to an optimiser, i.e., the effective step size at iteration `n` is `eta / (1 + γ * n)` where `eta` is the initial step size. The wrapped optimiser's step size is not modified.
|
Apply inverse time decay to an optimiser, so that the effective step size at
|
||||||
|
iteration `n` is `eta / (1 + γ * n)` where `eta` is the initial step size.
|
||||||
|
The wrapped optimiser's step size is not modified.
|
||||||
|
|
||||||
## Parameters
|
# Examples
|
||||||
- gamma (γ): Defaults to `0.001`
|
|
||||||
|
|
||||||
## Example
|
|
||||||
```julia
|
```julia
|
||||||
Optimiser(InvDecay(..), Opt(..))
|
Optimiser(InvDecay(..), Opt(..))
|
||||||
```
|
```
|
||||||
|
@ -469,20 +474,24 @@ function apply!(o::InvDecay, x, Δ)
|
||||||
end
|
end
|
||||||
|
|
||||||
"""
|
"""
|
||||||
ExpDecay(eta, decay, decay_step, clip)
|
ExpDecay(η = 0.001, decay = 0.1, decay_step = 1000, clip = 1e-4)
|
||||||
|
|
||||||
Discount the learning rate `eta` by a multiplicative factor `decay` every `decay_step` till a minimum of `clip`.
|
Discount the learning rate `η` by the factor `decay` every `decay_step` steps till
|
||||||
|
a minimum of `clip`.
|
||||||
|
|
||||||
## Parameters
|
# Parameters
|
||||||
- Learning Rate (eta): Defaults to `0.001`.
|
- Learning rate (`η`): Amount by which gradients are discounted before updating
|
||||||
- decay: Factor by which the learning rate is discounted. Defaults to `0.1`.
|
the weights.
|
||||||
- decay_step: Schedules decay operations by setting number of steps between two decay operations. Defaults to `1000`.
|
- `decay`: Factor by which the learning rate is discounted.
|
||||||
- clip: Minimum value of learning rate. Defaults to `1e-4`.
|
- `decay_step`: Schedule decay operations by setting the number of steps between
|
||||||
|
two decay operations.
|
||||||
|
- `clip`: Minimum value of learning rate.
|
||||||
|
|
||||||
## Example
|
# Examples
|
||||||
To apply exponential decay to an optimiser:
|
To apply exponential decay to an optimiser:
|
||||||
```julia
|
```julia
|
||||||
Optimiser(ExpDecay(..), Opt(..))
|
Optimiser(ExpDecay(..), Opt(..))
|
||||||
|
|
||||||
opt = Optimiser(ExpDecay(), ADAM())
|
opt = Optimiser(ExpDecay(), ADAM())
|
||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
|
@ -507,12 +516,12 @@ function apply!(o::ExpDecay, x, Δ)
|
||||||
end
|
end
|
||||||
|
|
||||||
"""
|
"""
|
||||||
WeightDecay(wd)
|
WeightDecay(wd = 0)
|
||||||
|
|
||||||
Decays the weight by `wd`
|
Decay weights by `wd`.
|
||||||
|
|
||||||
## Parameters
|
# Parameters
|
||||||
- weight decay (wd): 0
|
- Weight decay (`wd`)
|
||||||
"""
|
"""
|
||||||
mutable struct WeightDecay
|
mutable struct WeightDecay
|
||||||
wd::Real
|
wd::Real
|
||||||
|
|
|
@ -2,23 +2,25 @@ using Juno
|
||||||
import Zygote: Params, gradient
|
import Zygote: Params, gradient
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
"""
|
"""
|
||||||
update!(opt, p, g)
|
|
||||||
update!(opt, ps::Params, gs)
|
|
||||||
|
|
||||||
Perform an update step of the parameters `ps` (or the single parameter `p`)
|
|
||||||
according to optimizer `opt` and the gradients `gs` (the gradient `g`).
|
|
||||||
|
|
||||||
As a result, the parameters are mutated and the optimizer's internal state may change.
|
|
||||||
|
|
||||||
update!(x, x̄)
|
update!(x, x̄)
|
||||||
|
|
||||||
Update the array `x` according to `x .-= x̄`.
|
Update the array `x` according to `x .-= x̄`.
|
||||||
"""
|
"""
|
||||||
function update!(x::AbstractArray, x̄)
|
function update!(x::AbstractArray, x̄)
|
||||||
x .-= x̄
|
x .-= x̄
|
||||||
end
|
end
|
||||||
|
|
||||||
|
"""
|
||||||
|
update!(opt, p, g)
|
||||||
|
update!(opt, ps::Params, gs)
|
||||||
|
|
||||||
|
Perform an update step of the parameters `ps` (or the single parameter `p`)
|
||||||
|
according to optimizer `opt` and the gradients `gs` (the gradient `g`).
|
||||||
|
|
||||||
|
As a result, the parameters are mutated and the optimizer's internal state may change.
|
||||||
|
"""
|
||||||
function update!(opt, x, x̄)
|
function update!(opt, x, x̄)
|
||||||
x .-= apply!(opt, x, x̄)
|
x .-= apply!(opt, x, x̄)
|
||||||
end
|
end
|
||||||
|
@ -41,11 +43,10 @@ struct StopException <: Exception end
|
||||||
stop()
|
stop()
|
||||||
|
|
||||||
Call `Flux.stop()` in a callback to indicate when a callback condition is met.
|
Call `Flux.stop()` in a callback to indicate when a callback condition is met.
|
||||||
This would trigger the train loop to stop and exit.
|
This will trigger the train loop to stop and exit.
|
||||||
|
|
||||||
|
# Examples
|
||||||
```julia
|
```julia
|
||||||
# Example callback:
|
|
||||||
|
|
||||||
cb = function ()
|
cb = function ()
|
||||||
accuracy() > 0.9 && Flux.stop()
|
accuracy() > 0.9 && Flux.stop()
|
||||||
end
|
end
|
||||||
|
@ -58,19 +59,19 @@ end
|
||||||
"""
|
"""
|
||||||
train!(loss, params, data, opt; cb)
|
train!(loss, params, data, opt; cb)
|
||||||
|
|
||||||
For each datapoint `d` in `data` computes the gradient of `loss(d...)` through
|
For each datapoint `d` in `data` compute the gradient of `loss(d...)` through
|
||||||
backpropagation and calls the optimizer `opt`.
|
backpropagation and call the optimizer `opt`.
|
||||||
|
|
||||||
In case datapoints `d` are of numeric array type, assumes no splatting is needed
|
In case datapoints `d` are of numeric array type, assume no splatting is needed
|
||||||
and computes the gradient of `loss(d)`.
|
and compute the gradient of `loss(d)`.
|
||||||
|
|
||||||
Takes a callback as keyword argument `cb`. For example, this will print "training"
|
A callback is given with the keyword argument `cb`. For example, this will print
|
||||||
every 10 seconds:
|
"training" every 10 seconds (using [`Flux.throttle`](@ref)):
|
||||||
|
|
||||||
train!(loss, params, data, opt,
|
train!(loss, params, data, opt,
|
||||||
cb = throttle(() -> println("training"), 10))
|
cb = throttle(() -> println("training"), 10))
|
||||||
|
|
||||||
The callback can call `Flux.stop()` to interrupt the training loop.
|
The callback can call [`Flux.stop`](@ref) to interrupt the training loop.
|
||||||
|
|
||||||
Multiple optimisers and callbacks can be passed to `opt` and `cb` as arrays.
|
Multiple optimisers and callbacks can be passed to `opt` and `cb` as arrays.
|
||||||
"""
|
"""
|
||||||
|
@ -106,11 +107,12 @@ end
|
||||||
Run `body` `N` times. Mainly useful for quickly doing multiple epochs of
|
Run `body` `N` times. Mainly useful for quickly doing multiple epochs of
|
||||||
training in a REPL.
|
training in a REPL.
|
||||||
|
|
||||||
```julia
|
# Examples
|
||||||
julia> @epochs 2 println("hello")
|
```jldoctest
|
||||||
INFO: Epoch 1
|
julia> Flux.@epochs 2 println("hello")
|
||||||
|
[ Info: Epoch 1
|
||||||
hello
|
hello
|
||||||
INFO: Epoch 2
|
[ Info: Epoch 2
|
||||||
hello
|
hello
|
||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
|
|
173
src/utils.jl
173
src/utils.jl
|
@ -1,10 +1,40 @@
|
||||||
# Arrays
|
# Arrays
|
||||||
nfan() = 1, 1 #fan_in, fan_out
|
nfan() = 1, 1 # fan_in, fan_out
|
||||||
nfan(n) = 1, n #A vector is treated as a n×1 matrix
|
nfan(n) = 1, n # A vector is treated as a n×1 matrix
|
||||||
nfan(n_out, n_in) = n_in, n_out #In case of Dense kernels: arranged as matrices
|
nfan(n_out, n_in) = n_in, n_out # In case of Dense kernels: arranged as matrices
|
||||||
nfan(dims...) = prod(dims[1:end-2]) .* (dims[end-1], dims[end]) #In case of convolution kernels
|
nfan(dims...) = prod(dims[1:end-2]) .* (dims[end-1], dims[end]) # In case of convolution kernels
|
||||||
|
|
||||||
|
"""
|
||||||
|
glorot_uniform(dims...)
|
||||||
|
|
||||||
|
Return an `Array` of size `dims` containing random variables taken from a uniform
|
||||||
|
distribution in the interval ``[-x, x]``, where `x = sqrt(24 / sum(dims)) / 2`.
|
||||||
|
|
||||||
|
# Examples
|
||||||
|
```jldoctest; setup = :(using Random; Random.seed!(0))
|
||||||
|
julia> Flux.glorot_uniform(2, 3)
|
||||||
|
2×3 Array{Float32,2}:
|
||||||
|
0.601094 -0.57414 -0.814925
|
||||||
|
0.900868 0.805994 0.057514
|
||||||
|
```
|
||||||
|
"""
|
||||||
glorot_uniform(dims...) = (rand(Float32, dims...) .- 0.5f0) .* sqrt(24.0f0 / sum(nfan(dims...)))
|
glorot_uniform(dims...) = (rand(Float32, dims...) .- 0.5f0) .* sqrt(24.0f0 / sum(nfan(dims...)))
|
||||||
|
|
||||||
|
"""
|
||||||
|
glorot_normal(dims...)
|
||||||
|
|
||||||
|
Return an `Array` of size `dims` containing random variables taken from a normal
|
||||||
|
distribution with mean 0 and standard deviation `(2 / sum(dims))`.
|
||||||
|
|
||||||
|
# Examples
|
||||||
|
```jldoctest; setup = :(using Random; Random.seed!(0))
|
||||||
|
julia> Flux.glorot_normal(3, 2)
|
||||||
|
3×2 Array{Float32,2}:
|
||||||
|
0.429505 -0.0852891
|
||||||
|
0.523935 0.371009
|
||||||
|
-0.223261 0.188052
|
||||||
|
```
|
||||||
|
"""
|
||||||
glorot_normal(dims...) = randn(Float32, dims...) .* sqrt(2.0f0 / sum(nfan(dims...)))
|
glorot_normal(dims...) = randn(Float32, dims...) .* sqrt(2.0f0 / sum(nfan(dims...)))
|
||||||
|
|
||||||
ones(T::Type, dims...) = Base.ones(T, dims...)
|
ones(T::Type, dims...) = Base.ones(T, dims...)
|
||||||
|
@ -13,9 +43,81 @@ zeros(T::Type, dims...) = Base.zeros(T, dims...)
|
||||||
ones(dims...) = Base.ones(Float32, dims...)
|
ones(dims...) = Base.ones(Float32, dims...)
|
||||||
zeros(dims...) = Base.zeros(Float32, dims...)
|
zeros(dims...) = Base.zeros(Float32, dims...)
|
||||||
|
|
||||||
|
"""
|
||||||
|
unsqueeze(xs, dim)
|
||||||
|
|
||||||
|
Return `xs` reshaped into an `Array` one dimensionality higher than `xs`,
|
||||||
|
where `dim` indicates in which dimension `xs` is extended.
|
||||||
|
|
||||||
|
# Examples
|
||||||
|
```jldoctest
|
||||||
|
julia> xs = [[1, 2], [3, 4], [5, 6]]
|
||||||
|
3-element Array{Array{Int64,1},1}:
|
||||||
|
[1, 2]
|
||||||
|
[3, 4]
|
||||||
|
[5, 6]
|
||||||
|
|
||||||
|
julia> Flux.unsqueeze(xs, 1)
|
||||||
|
1×3 Array{Array{Int64,1},2}:
|
||||||
|
[1, 2] [3, 4] [5, 6]
|
||||||
|
|
||||||
|
julia> Flux.unsqueeze([1 2; 3 4], 2)
|
||||||
|
2×1×2 Array{Int64,3}:
|
||||||
|
[:, :, 1] =
|
||||||
|
1
|
||||||
|
3
|
||||||
|
|
||||||
|
[:, :, 2] =
|
||||||
|
2
|
||||||
|
4
|
||||||
|
```
|
||||||
|
"""
|
||||||
unsqueeze(xs, dim) = reshape(xs, (size(xs)[1:dim-1]..., 1, size(xs)[dim:end]...))
|
unsqueeze(xs, dim) = reshape(xs, (size(xs)[1:dim-1]..., 1, size(xs)[dim:end]...))
|
||||||
|
|
||||||
|
"""
|
||||||
|
stack(xs, dim)
|
||||||
|
|
||||||
|
Concatenate the given `Array` of `Array`s `xs` into a single `Array` along the
|
||||||
|
given dimension `dim`.
|
||||||
|
|
||||||
|
# Examples
|
||||||
|
```jldoctest
|
||||||
|
julia> xs = [[1, 2], [3, 4], [5, 6]]
|
||||||
|
3-element Array{Array{Int64,1},1}:
|
||||||
|
[1, 2]
|
||||||
|
[3, 4]
|
||||||
|
[5, 6]
|
||||||
|
|
||||||
|
julia> Flux.stack(xs, 1)
|
||||||
|
3×2 Array{Int64,2}:
|
||||||
|
1 2
|
||||||
|
3 4
|
||||||
|
5 6
|
||||||
|
|
||||||
|
julia> cat(xs, dims=1)
|
||||||
|
3-element Array{Array{Int64,1},1}:
|
||||||
|
[1, 2]
|
||||||
|
[3, 4]
|
||||||
|
[5, 6]
|
||||||
|
```
|
||||||
|
"""
|
||||||
stack(xs, dim) = cat(unsqueeze.(xs, dim)..., dims=dim)
|
stack(xs, dim) = cat(unsqueeze.(xs, dim)..., dims=dim)
|
||||||
|
|
||||||
|
"""
|
||||||
|
unstack(xs, dim)
|
||||||
|
|
||||||
|
Unroll the given `xs` into an `Array` of `Array`s along the given dimension `dim`.
|
||||||
|
|
||||||
|
# Examples
|
||||||
|
```jldoctest
|
||||||
|
julia> Flux.unstack([1 3 5 7; 2 4 6 8], 2)
|
||||||
|
4-element Array{Array{Int64,1},1}:
|
||||||
|
[1, 2]
|
||||||
|
[3, 4]
|
||||||
|
[5, 6]
|
||||||
|
[7, 8]
|
||||||
|
```
|
||||||
|
"""
|
||||||
unstack(xs, dim) = [copy(selectdim(xs, dim, i)) for i in 1:size(xs, dim)]
|
unstack(xs, dim) = [copy(selectdim(xs, dim, i)) for i in 1:size(xs, dim)]
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
@ -23,9 +125,16 @@ unstack(xs, dim) = [copy(selectdim(xs, dim, i)) for i in 1:size(xs, dim)]
|
||||||
|
|
||||||
Split `xs` into `n` parts.
|
Split `xs` into `n` parts.
|
||||||
|
|
||||||
```julia
|
# Examples
|
||||||
julia> chunk(1:10, 3)
|
```jldoctest
|
||||||
3-element Array{Array{Int64,1},1}:
|
julia> Flux.chunk(1:10, 3)
|
||||||
|
3-element Array{UnitRange{Int64},1}:
|
||||||
|
1:4
|
||||||
|
5:8
|
||||||
|
9:10
|
||||||
|
|
||||||
|
julia> Flux.chunk(collect(1:10), 3)
|
||||||
|
3-element Array{SubArray{Int64,1,Array{Int64,1},Tuple{UnitRange{Int64}},true},1}:
|
||||||
[1, 2, 3, 4]
|
[1, 2, 3, 4]
|
||||||
[5, 6, 7, 8]
|
[5, 6, 7, 8]
|
||||||
[9, 10]
|
[9, 10]
|
||||||
|
@ -40,11 +149,12 @@ batchindex(xs, i) = (reverse(Base.tail(reverse(axes(xs))))..., i)
|
||||||
|
|
||||||
Count the number of times that each element of `xs` appears.
|
Count the number of times that each element of `xs` appears.
|
||||||
|
|
||||||
```julia
|
# Examples
|
||||||
julia> frequencies(['a','b','b'])
|
```jldoctest
|
||||||
|
julia> Flux.frequencies(['a','b','b'])
|
||||||
Dict{Char,Int64} with 2 entries:
|
Dict{Char,Int64} with 2 entries:
|
||||||
'b' => 2
|
|
||||||
'a' => 1
|
'a' => 1
|
||||||
|
'b' => 2
|
||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
function frequencies(xs)
|
function frequencies(xs)
|
||||||
|
@ -64,8 +174,9 @@ squeezebatch(x) = reshape(x, head(size(x)))
|
||||||
|
|
||||||
Batch the arrays in `xs` into a single array.
|
Batch the arrays in `xs` into a single array.
|
||||||
|
|
||||||
```julia
|
# Examples
|
||||||
julia> batch([[1,2,3],[4,5,6]])
|
```jldoctest
|
||||||
|
julia> Flux.batch([[1,2,3],[4,5,6]])
|
||||||
3×2 Array{Int64,2}:
|
3×2 Array{Int64,2}:
|
||||||
1 4
|
1 4
|
||||||
2 5
|
2 5
|
||||||
|
@ -82,6 +193,25 @@ function batch(xs)
|
||||||
return data
|
return data
|
||||||
end
|
end
|
||||||
|
|
||||||
|
"""
|
||||||
|
Return the given sequence padded with `p` up to a maximum length of `n`.
|
||||||
|
|
||||||
|
# Examples
|
||||||
|
```jldoctest
|
||||||
|
julia> rpad([1, 2], 4, 0)
|
||||||
|
4-element Array{Int64,1}:
|
||||||
|
1
|
||||||
|
2
|
||||||
|
0
|
||||||
|
0
|
||||||
|
|
||||||
|
julia> rpad([1, 2, 3], 2, 0)
|
||||||
|
3-element Array{Int64,1}:
|
||||||
|
1
|
||||||
|
2
|
||||||
|
3
|
||||||
|
```
|
||||||
|
"""
|
||||||
Base.rpad(v::AbstractVector, n::Integer, p) = [v; fill(p, max(n - length(v), 0))]
|
Base.rpad(v::AbstractVector, n::Integer, p) = [v; fill(p, max(n - length(v), 0))]
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
@ -90,8 +220,9 @@ Base.rpad(v::AbstractVector, n::Integer, p) = [v; fill(p, max(n - length(v), 0))
|
||||||
Take a list of `N` sequences, and turn them into a single sequence where each
|
Take a list of `N` sequences, and turn them into a single sequence where each
|
||||||
item is a batch of `N`. Short sequences will be padded by `pad`.
|
item is a batch of `N`. Short sequences will be padded by `pad`.
|
||||||
|
|
||||||
```julia
|
# Examples
|
||||||
julia> batchseq([[1, 2, 3], [4, 5]], 0)
|
```jldoctest
|
||||||
|
julia> Flux.batchseq([[1, 2, 3], [4, 5]], 0)
|
||||||
3-element Array{Array{Int64,1},1}:
|
3-element Array{Array{Int64,1},1}:
|
||||||
[1, 4]
|
[1, 4]
|
||||||
[2, 5]
|
[2, 5]
|
||||||
|
@ -148,11 +279,15 @@ end
|
||||||
# Other
|
# Other
|
||||||
|
|
||||||
"""
|
"""
|
||||||
Returns a function that when invoked, will only be triggered at most once
|
throttle(f, timeout; leading=true, trailing=false)
|
||||||
during `timeout` seconds. Normally, the throttled function will run
|
|
||||||
as much as it can, without ever going more than once per `wait` duration;
|
Return a function that when invoked, will only be triggered at most once
|
||||||
but if you'd like to disable the execution on the leading edge, pass
|
during `timeout` seconds.
|
||||||
`leading=false`. To enable execution on the trailing edge, ditto.
|
|
||||||
|
Normally, the throttled function will run as much as it can, without ever
|
||||||
|
going more than once per `wait` duration; but if you'd like to disable the
|
||||||
|
execution on the leading edge, pass `leading=false`. To enable execution on
|
||||||
|
the trailing edge, pass `trailing=true`.
|
||||||
"""
|
"""
|
||||||
function throttle(f, timeout; leading=true, trailing=false)
|
function throttle(f, timeout; leading=true, trailing=false)
|
||||||
cooldown = true
|
cooldown = true
|
||||||
|
|
|
@ -41,7 +41,8 @@ Random.seed!(0)
|
||||||
end
|
end
|
||||||
|
|
||||||
@testset "Docs" begin
|
@testset "Docs" begin
|
||||||
if VERSION >= v"1.2"
|
if VERSION >= v"1.4"
|
||||||
|
DocMeta.setdocmeta!(Flux, :DocTestSetup, :(using Flux); recursive=true)
|
||||||
doctest(Flux)
|
doctest(Flux)
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
Loading…
Reference in New Issue