Mike J Innes c51f5afb3d clarity
2017-09-27 18:37:07 +01:00


# Training
To train a model we need three things:

* A *model loss function* that evaluates how well the model is doing on some input data.
* A collection of data points that will be provided to the loss function.
* An [optimiser](optimisers.md) that will update the model parameters appropriately.

With these we can call `Flux.train!`:
```julia
Flux.train!(modelLoss, data, opt)
```
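
Roughly speaking, `train!` iterates over `data`, evaluates the loss on each data point, and lets the optimiser update the parameters. The plain-Julia sketch below spells that loop out for a one-parameter linear model. It is not Flux's implementation. The names `predict`, `grad_loss`, `descend!` and `sketch_train!` are hypothetical, the gradient is derived by hand, and the "optimiser" is a fixed-step gradient descent update.

```julia
# Hypothetical one-parameter model ŷ = w * x. The weight lives in a
# 1-element vector so the loop below can update it in place.
w = [0.0]
predict(x) = w[1] * x

# Model loss: squared error for a single (x, y) data point.
loss(x, y) = (predict(x) - y)^2

# Gradient of the loss with respect to w, derived by hand:
# d/dw (w*x - y)^2 = 2 * (w*x - y) * x
grad_loss(x, y) = 2 * (predict(x) - y) * x

# Stand-in "optimiser": a plain gradient descent step with learning rate η.
descend!(w, g, η) = (w[1] -= η * g; w)

# The training loop: for each data point, compute the gradient and update w.
function sketch_train!(data, η)
    for (x, y) in data
        descend!(w, grad_loss(x, y), η)
    end
    return w
end

# Toy data drawn from y = 3x, repeated so the loop takes many steps.
data = repeat([(1.0, 3.0), (2.0, 6.0), (3.0, 9.0)], 100)
sketch_train!(data, 0.01)
```

After training, `w[1]` is very close to `3.0`. In Flux the gradients come from automatic differentiation rather than a hand-written function like `grad_loss`, and the update rule is whatever optimiser you pass in.
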
There are plenty of examples in the [model zoo](https://github.com/FluxML/model-zoo).
## Loss Functions
The `loss` that we defined in [basics](../models/basics.md) is completely valid for training. We can also define a loss in terms of some model:
```julia
using Flux

m = Chain(
  Dense(784, 32, σ),
  Dense(32, 10), softmax)

# Model loss function
loss(x, y) = Flux.mse(m(x), y)
```
The loss will almost always be defined in terms of some *cost function* that measures the distance of the prediction `m(x)` from the target `y`. Flux has several of these built in, like `mse` for mean squared error or `logloss` for cross entropy loss, but you can calculate it however you want.
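
To illustrate what such cost functions compute, here are plain-Julia versions of mean squared error and cross entropy. `my_mse` and `my_crossentropy` are hypothetical names. These are textbook definitions and may differ in details, such as averaging or numerical safeguards, from the functions Flux ships.

```julia
# Mean squared error: average of the squared differences between
# the prediction ŷ and the target y.
my_mse(ŷ, y) = sum((ŷ .- y) .^ 2) / length(y)

# Cross entropy between predicted probabilities ŷ (e.g. a softmax output)
# and a one-hot target y, averaged over the columns (samples) of y.
my_crossentropy(ŷ, y) = -sum(y .* log.(ŷ)) / size(y, 2)

my_mse([1.0, 2.0], [1.0, 4.0])           # (0^2 + 2^2) / 2 = 2.0
my_crossentropy([0.5, 0.5], [1.0, 0.0])  # -log(0.5) ≈ 0.693
```

Either function could take the place of `Flux.mse` in the `loss` definition above, since a loss is just an ordinary Julia function of the prediction and the target.
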
## Callbacks
`train!` takes an additional keyword argument, `cb`, through which you can pass a callback function to observe the training process. For example:
```julia
Flux.train!(loss, data, opt, cb = () -> println("training"))
```
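
The callback is simply a zero-argument function that the training loop invokes as it goes. The sketch below uses a hypothetical stand-in for that loop. `toy_train!` is not a Flux function, and the per-batch work is left to a caller-supplied `update!`. It checks that the callback fires once per batch.

```julia
# Hypothetical training loop (not Flux's): `update!` stands in for the
# per-batch loss/gradient/optimiser work, and `cb` runs after each batch.
function toy_train!(update!, data; cb = () -> nothing)
    for batch in data
        update!(batch)
        cb()
    end
    return nothing
end

# Count how many times the callback fires over three batches.
calls = Ref(0)
toy_train!(batch -> nothing, [1, 2, 3]; cb = () -> (calls[] += 1))
calls[]  # 3
```
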
Callbacks are called for every batch of training data. To call one less often, wrap it with `Flux.throttle(f, timeout)`, which prevents `f` from being called more than once every `timeout` seconds.

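
The idea behind throttling fits in a few lines. You wrap `f` in a closure that remembers when it last ran and drops any call that arrives within `timeout` seconds of that. `my_throttle` is a hypothetical, simplified version that only fires on the leading edge of each window. It is not Flux's implementation, which may offer more options.

```julia
# Return a wrapper around `f` that runs `f` at most once every `timeout`
# seconds; calls arriving sooner are dropped and return `nothing`.
function my_throttle(f, timeout)
    last_run = -Inf  # time of the most recent call that went through
    return function (args...; kwargs...)
        t = time()
        if t - last_run >= timeout
            last_run = t
            return f(args...; kwargs...)
        end
        return nothing
    end
end

# A counting callback, throttled to at most once every 5 seconds.
n = Ref(0)
throttled = my_throttle(() -> (n[] += 1), 5)
for _ in 1:1000
    throttled()  # simulates 1000 quick batches
end
n[]  # 1, because the loop finishes well within 5 seconds
```
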
A more typical callback might look like this:
```julia
test_x, test_y = # ... create single batch of test data ...
evalcb() = @show(loss(test_x, test_y))
Flux.train!(loss, data, opt,
            cb = Flux.throttle(evalcb, 5))
```