# Training
To actually train a model we need three things:
* A *model loss function* that evaluates how well the model is doing given some input data.
* A collection of data points that will be provided to the loss function.
* An [optimiser](optimisers.md) that will update the model parameters appropriately.
With these we can call `Flux.train!`:
```julia
Flux.train!(loss, data, opt)
```
There are plenty of examples in the [model zoo](https://github.com/FluxML/model-zoo).
## Loss Functions
The `loss` that we defined in [basics](../models/basics.md) is completely valid for training. We can also define a loss in terms of some model:
```julia
m = Chain(
  Dense(784, 32, σ),
  Dense(32, 10), softmax)
# Model loss function
loss(x, y) = Flux.mse(m(x), y)
# later
Flux.train!(loss, data, opt)
```
The loss will almost always be defined in terms of some *cost function* that measures the distance of the prediction `m(x)` from the target `y`. Flux has several of these built in, like `mse` for mean squared error or `crossentropy` for cross entropy loss, but you can calculate it however you want.
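For instance, since the model above ends with `softmax`, cross entropy is a natural fit. A loss along those lines might look like this (a sketch, reusing the `m` defined above):
```julia
# Cross entropy pairs naturally with a softmax output layer.
loss(x, y) = Flux.crossentropy(m(x), y)
```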
## Datasets
The `data` argument provides a collection of data to train with (usually a set of inputs `x` and target outputs `y`). For example, here's a dummy data set with only one data point:
```julia
x = rand(784) # a dummy input
y = rand(10)  # a dummy target output
data = [(x, y)]
```
`Flux.train!` will call `loss(x, y)`, calculate gradients, update the weights and then move on to the next data point if there is one. We can train the model on the same data three times:
```julia
data = [(x, y), (x, y), (x, y)]
# Or equivalently
data = Iterators.repeated((x, y), 3)
```
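Conceptually, the training loop behaves like the following (a rough sketch, assuming this version's tracker-based API, where `Flux.back!` computes gradients from a loss and calling the optimiser applies the update):
```julia
for (x, y) in data
  l = loss(x, y) # forward pass
  Flux.back!(l)  # compute gradients of the parameters
  opt()          # let the optimiser update the weights
end
```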
It's common to load the `x`s and `y`s separately. In this case you can use `zip`:
```julia
xs = [rand(784), rand(784), rand(784)]
ys = [rand( 10), rand( 10), rand( 10)]
data = zip(xs, ys)
```
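Each element of `data` is passed to `loss` as-is, so you can also pack several samples into a single batch; Flux's layers accept matrices whose columns are individual samples. A sketch, reusing the `xs` and `ys` above:
```julia
# Pack three samples into one batch: each column is one sample.
X = hcat(xs...) # 784×3 input matrix
Y = hcat(ys...) # 10×3 target matrix
data = [(X, Y)] # train! now sees a single batched data point
```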
## Callbacks
`train!` takes an additional argument, `cb`, that can be used to run callbacks, letting you observe the training process. For example:
```julia
train!(loss, data, opt, cb = () -> println("training"))
```
Callbacks are called for every batch of training data. You can slow this down using `Flux.throttle(f, timeout)`, which prevents `f` from being called more than once every `timeout` seconds.
A more typical callback might look like this:
```julia
test_x, test_y = # ... create single batch of test data ...
evalcb() = @show(loss(test_x, test_y))
Flux.train!(loss, data, opt,
            cb = Flux.throttle(evalcb, 5))
```
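Putting the pieces together, a minimal end-to-end run might look like the following. This is only a sketch on random data (so the model learns nothing meaningful), and it assumes the `SGD` optimiser described on the [optimisers](optimisers.md) page:
```julia
using Flux

m = Chain(
  Dense(784, 32, σ),
  Dense(32, 10), softmax)

loss(x, y) = Flux.crossentropy(m(x), y)

# Three copies of a single dummy data point.
x, y = rand(784), rand(10)
data = Iterators.repeated((x, y), 3)

opt = SGD(params(m), 0.1)

evalcb() = @show(loss(x, y))

Flux.train!(loss, data, opt, cb = Flux.throttle(evalcb, 5))
```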