training docs
This commit is contained in:
parent
726a8acefe
commit
c2db42d38e
|
@ -1,4 +1,47 @@
|
|||
To actually train a model we need three things:
|
||||
|
||||
* A *loss function*, that evaluates how well a model is doing given some input data.
|
||||
* A collection of data points that will be provided to the loss function.
|
||||
* An [optimiser](./optimisers.html) that will update the model parameters appropriately.
|
||||
|
||||
With these we can call `Flux.train!`:
|
||||
|
||||
```julia
|
||||
Flux.train!(loss, repeated((x,y), 1000), SGD(params(m), 0.1),
|
||||
cb = throttle(() -> @show(loss(x, y)), 5))
|
||||
Flux.train!(loss, data, opt)
|
||||
```
|
||||
|
||||
There are plenty of examples in the [model zoo](https://github.com/FluxML/model-zoo).
|
||||
|
||||
## Loss Functions
|
||||
|
||||
The `loss` that we defined in [basics](../models/basics.html) is completely valid for training. We can also define a loss in terms of some model:
|
||||
|
||||
```julia
|
||||
m = Chain(
|
||||
Dense(784, 32, σ),
|
||||
Dense(32, 10), softmax)
|
||||
|
||||
loss(x, y) = Flux.mse(m(x), y)
|
||||
```
|
||||
|
||||
The loss will almost always be defined in terms of some *cost function* that measures the distance of the prediction `m(x)` from the target `y`. Flux has several of these built in, like `mse` for mean squared error or `logloss` for cross entropy loss, but you can calculate it however you want.
|
||||
|
||||
## Callbacks
|
||||
|
||||
`train!` takes an additional argument, `cb`, that's used for callbacks so that you can observe the training process. For example:
|
||||
|
||||
```julia
|
||||
train!(loss, data, opt, cb = () -> println("training"))
|
||||
```
|
||||
|
||||
Callbacks are called for every batch of training data. You can slow this down using `Flux.throttle(f, timeout)` which prevents `f` from being called more than once every `timeout` seconds.
|
||||
|
||||
A more typical callback might look like this:
|
||||
|
||||
```julia
|
||||
test_x, test_y = # ... create single batch of test data ...
|
||||
evalcb() = @show(loss(test_x, test_y))
|
||||
|
||||
Flux.train!(loss, data, opt,
|
||||
cb = throttle(evalcb, 5))
|
||||
```
|
||||
|
|
Loading…
Reference in New Issue