# Training
To actually train a model we need three things, in addition to the tracked parameters that will be fitted:

* An *objective function* that evaluates how well a model is doing given some input data.
* A collection of data points that will be provided to the objective function.
* An [optimiser](optimisers.md) that will update the model parameters appropriately.

With these we can call `Flux.train!`:
```julia
Flux.train!(objective, params, data, opt)
```
At first glance it may seem strange that the model we want to train is not one of the arguments of `Flux.train!`. However, the optimiser's target is not the model itself but the objective function, which measures the departure between modelled and observed data. The optimiser minimises it by updating the parameters in `params`. In other words, the model is implicitly defined by the objective function, so there is no need to pass it explicitly. Passing a single objective function, rather than a model and a cost function separately (see below), is also more flexible: the objective can be computed however is most convenient or efficient.

There are plenty of examples in the [model zoo](https://github.com/FluxML/model-zoo).
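
Here is a plain-Julia sketch of why the model does not need to be passed separately. It has no Flux dependency, and all names in it are hypothetical. The objective is a closure over the model, and the model in turn closes over its parameter arrays. Mutating those arrays in place, which is what an optimiser does, therefore changes what the objective computes.

```julia
# Hypothetical parameters of a tiny linear model, W * x .+ b
W = [2.0 0.0; 0.0 2.0]
b = [0.0, 0.0]

# The model closes over W and b
model(x) = W * x .+ b

# The objective closes over the model (sum of squared errors)
objective(x, y) = sum(abs2, model(x) .- y)

# The parameter collection holds references to the very same arrays
ps = [W, b]

x, y = [1.0, 1.0], [1.0, 1.0]
before = objective(x, y)   # model(x) == [2.0, 2.0], so this is 2.0

# A toy "optimiser step" that only touches ps, halving every parameter in place
for p in ps
    p .*= 0.5
end

after = objective(x, y)    # W is now the identity, so this is 0.0
```

A real optimiser moves the parameters along the negative gradient rather than scaling them. The mechanism is the same, though: only the objective and the parameters are needed, because updating the parameters in place is what changes the model.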
## Loss Functions

The objective function must return a number representing how far the model is from its target: the *loss* of the model. The `loss` function that we defined in [basics](../models/basics.md) will work as an objective. We can also define an objective in terms of some model:
```julia
using Flux

m = Chain(
  Dense(784, 32, σ),
  Dense(32, 10), softmax)

# The objective closes over the model m
loss(x, y) = Flux.mse(m(x), y)
ps = Flux.params(m)

# later, once `data` and `opt` are defined
Flux.train!(loss, ps, data, opt)
```
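
For intuition, the `Chain` above roughly computes `softmax(W2 * σ.(W1 * x .+ b1) .+ b2)`. This is a plain-Julia sketch without Flux. The weight names, the small random initialisation and the helper functions are illustrative assumptions, not Flux internals.

```julia
# Plain-Julia stand-ins for σ and softmax
sigmoid(z) = 1 / (1 + exp(-z))

function softmax_sketch(v)
    e = exp.(v .- maximum(v))   # shift by the maximum for numerical stability
    return e ./ sum(e)
end

# Weights with the same shapes as Dense(784, 32) and Dense(32, 10);
# the small random initialisation is an arbitrary choice for illustration
W1, b1 = 0.01 .* randn(32, 784), zeros(32)
W2, b2 = 0.01 .* randn(10, 32), zeros(10)

# Dense(784, 32, σ) -> Dense(32, 10) -> softmax
forward(x) = softmax_sketch(W2 * sigmoid.(W1 * x .+ b1) .+ b2)

p = forward(rand(784))   # 10 positive values summing to 1
```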
The objective will almost always be defined in terms of some *cost function* that measures the distance of the prediction `m(x)` from the target `y`. Flux has several of these built in, like `mse` for mean squared error or `crossentropy` for cross entropy loss, but you can calculate it however you want.
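
The following plain-Julia definitions sketch what such cost functions compute for a single example. They are not Flux's own implementations, which may, for instance, average over a batch differently or accept extra keyword arguments. The `_sketch` names are hypothetical.

```julia
# Mean squared error: average of the squared differences
mse_sketch(ŷ, y) = sum(abs2, ŷ .- y) / length(y)

# Cross entropy between a probability vector ŷ (all entries > 0)
# and a one-hot target y
crossentropy_sketch(ŷ, y) = -sum(y .* log.(ŷ))

# Any function returning a number will do, e.g. mean absolute error
mae_sketch(ŷ, y) = sum(abs, ŷ .- y) / length(y)

mse_sketch([1.0, 2.0], [1.0, 4.0])            # (0 + 4) / 2 = 2.0
crossentropy_sketch([0.5, 0.5], [1.0, 0.0])   # -log(0.5) ≈ 0.693
mae_sketch([1.0, 2.0], [1.0, 4.0])            # (0 + 2) / 2 = 1.0
```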
## Datasets
The `data` argument provides a collection of data to train with (usually a set of inputs `x` and target outputs `y`). For example, here's a dummy data set with only one data point:
```julia
x = rand(784)
y = rand(10)
data = [(x, y)]
```
`Flux.train!` will call `loss(x, y)`, calculate gradients, update the weights and then move on to the next data point if there is one. We can train the model on the same data three times:
```julia
data = [(x, y), (x, y), (x, y)]
# Or equivalently
data = Iterators.repeated((x, y), 3)
```
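
The per-data-point steps described above can be sketched in plain Julia: call the loss, compute gradients, update the parameters, move on. In this sketch `toy_train!` and `fd_grad` are hypothetical helpers. A central finite difference stands in for Flux's automatic differentiation, and plain gradient descent with a fixed learning rate stands in for the optimiser.

```julia
# Toy model with a single weight w: prediction w * x
w = [0.0]
loss(x, y) = (w[1] * x - y)^2

# Central finite difference of f() with respect to p[i],
# a stand-in for automatic differentiation
function fd_grad(f, p, i; h = 1e-6)
    old = p[i]
    p[i] = old + h; fp = f()
    p[i] = old - h; fm = f()
    p[i] = old                      # restore the parameter
    return (fp - fm) / (2h)
end

# Hypothetical stand-in for Flux.train!: for each (x, y) in data, compute
# the gradient of loss(x, y) for every parameter, then take a plain
# gradient-descent step with learning rate η
function toy_train!(loss, ps, data, η)
    for (x, y) in data
        gs = [[fd_grad(() -> loss(x, y), p, i) for i in eachindex(p)] for p in ps]
        for (p, g) in zip(ps, gs)
            p .-= η .* g
        end
    end
end

x, y = 2.0, 4.0                          # target relation: y = 2x
data = Iterators.repeated((x, y), 50)    # the same point, 50 times
toy_train!(loss, [w], data, 0.05)

w[1]   # close to 2.0
```

All gradients are computed before any parameter is updated, so every update in one step sees the same parameter values.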
It's common to load the `x`s and `y`s separately. In this case you can use `zip`:
```julia
xs = [rand(784), rand(784), rand(784)]
ys = [rand( 10), rand( 10), rand( 10)]
data = zip(xs, ys)
```
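
`zip` pairs the i-th input with the i-th target lazily, without copying the arrays. The training loop therefore receives the same `(x, y)` tuples as in the vector examples above. A quick check in plain Julia:

```julia
xs = [rand(784), rand(784), rand(784)]
ys = [rand(10), rand(10), rand(10)]
data = zip(xs, ys)

first(data) == (xs[1], ys[1])   # true: the i-th input is paired with the i-th target
first(data)[1] === xs[1]        # true: no copy of the input is made
length(data)                    # 3
```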
Note that, by default, `train!` only loops over the data once (a single "epoch").
A convenient way to run multiple epochs from the REPL is provided by `@epochs`.
```julia
julia> using Flux: @epochs
julia> @epochs 2 println("hello")
INFO: Epoch 1
hello
INFO: Epoch 2
hello
julia> @epochs 2 Flux.train!(...)
# Train for two epochs
```
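
Conceptually, `@epochs` repeats an expression while logging the epoch number. Below is a minimal sketch of such a macro in plain Julia. It is named `@my_epochs` because it is a hypothetical stand-in, not Flux's definition.

```julia
# Evaluate `ex` n times in the caller's scope, logging each epoch number first.
# esc() makes n and ex refer to the caller's variables; the loop counter i
# stays hygienic, so it cannot clash with names used in ex.
macro my_epochs(n, ex)
    quote
        for i in 1:$(esc(n))
            @info "Epoch $i"
            $(esc(ex))
        end
    end
end

@my_epochs 2 println("hello")
```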
## Callbacks
`train!` takes an additional keyword argument, `cb`, that's used for callbacks so that you can observe the training process. For example:
```julia
Flux.train!(objective, ps, data, opt, cb = () -> println("training"))
```
Callbacks are called for every batch of training data. You can slow this down using `Flux.throttle(f, timeout)`, which prevents `f` from being called more than once every `timeout` seconds.
A more typical callback might look like this:
```julia
test_x, test_y = # ... create single batch of test data ...
evalcb() = @show(loss(test_x, test_y))
Flux.train!(objective, ps, data, opt,
            cb = Flux.throttle(evalcb, 5))
```
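
Throttling keeps `evalcb` from running on every batch. Below is a simplified sketch of the idea in plain Julia. `throttle_sketch` is hypothetical, and Flux's own `throttle` may handle details, such as calls that arrive during the waiting period, differently.

```julia
# Hypothetical, simplified throttle: the returned function calls f at most
# once every `timeout` seconds and silently drops calls that come too soon
function throttle_sketch(f, timeout)
    last_call = -Inf
    return function (args...)
        t = time()
        if t - last_call >= timeout
            last_call = t
            return f(args...)
        end
        return nothing
    end
end

evalcb_sketch() = println("evaluating")
cb = throttle_sketch(evalcb_sketch, 5)

for batch in 1:1000
    cb()   # prints "evaluating" once, unless the loop takes over 5 seconds
end
```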
Calling `Flux.stop()` in a callback will exit the training loop early.
```julia
cb = function ()
  # `accuracy()` stands for your own evaluation function
  accuracy() > 0.9 && Flux.stop()
end
```
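
One common way to implement such an early exit is for the callback to throw a special exception that the training loop catches. The sketch below takes that approach with a hypothetical `StopTraining` exception and a hypothetical `toy_loop!`. It also shows the callback being invoked once per batch. It illustrates the idea and is not necessarily how `Flux.stop()` is implemented.

```julia
# Hypothetical sentinel exception meaning "stop training now"
struct StopTraining <: Exception end
stop_training() = throw(StopTraining())

# Toy loop: run `step` on each batch, then call the callback `cb`.
# A StopTraining thrown by the callback ends the loop early;
# any other error is propagated unchanged.
function toy_loop!(step, data, cb)
    for batch in data
        step(batch)
        try
            cb()
        catch err
            err isa StopTraining && return :stopped
            rethrow()
        end
    end
    return :finished
end

seen = Int[]
status = toy_loop!(b -> push!(seen, b), 1:10,
                   () -> length(seen) >= 3 && stop_training())
# status == :stopped and seen == [1, 2, 3]
```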