explanations about params in `train!`

This commit is contained in:
Helios De Rosario 2019-11-14 16:22:31 +01:00 committed by GitHub
parent 074eb47246
commit ba4e3be0d3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 8 additions and 1 deletions

View File

@ -1,8 +1,9 @@
# Training
To actually train a model we need three things, in addition to the tracked parameters that will be fitted:
To actually train a model we need four things:
* A *objective function*, that evaluates how well a model is doing given some input data.
* The parameters of the model.
* A collection of data points that will be provided to the objective function.
* An [optimiser](optimisers.md) that will update the model parameters appropriately.
@ -34,6 +35,12 @@ The objective will almost always be defined in terms of some *cost function* tha
At first glance it may seem strange that the model that we want to train is not part of the input arguments of `Flux.train!` too. However the target of the optimizer is not the model itself, but the objective function that represents the departure between modelled and observed data. In other words, the model is implicitly defined in the objective function, and there is no need to give it explicitly. Passing the objective function instead of the model and a cost function separately provides more flexibility, and the possibility of optimizing the calculations.
## Model parameters
The model to be trained must have a set of tracked parameters that are used to calculate the gradients of the objective function. In the [basics](../models/basics.md) section it is explained how to create models with such parameters. The second argument of the function `Flux.train!` must be an object containing those parameters, which can be obtained from a model `m` as `params(m)`.
Such an object contains a reference to the model's parameters, not a copy, such that after their training, the model behaves according to their updated values.
## Datasets
The `data` argument provides a collection of data to train with (usually a set of inputs `x` and target outputs `y`). For example, here's a dummy data set with only one data point: