dataset docs
This commit is contained in:
parent
a05981732f
commit
58f4f1540f
|
@ -32,6 +32,32 @@ Flux.train!(loss, data, opt)
|
|||
|
||||
The loss will almost always be defined in terms of some *cost function* that measures the distance of the prediction `m(x)` from the target `y`. Flux has several of these built in, like `mse` for mean squared error or `logloss` for cross entropy loss, but you can calculate it however you want.
|
||||
|
||||
## Datasets
|
||||
|
||||
The `data` argument provides a collection of data to train with (usually a set of inputs `x` and a target outputs `y`). For example, here's a dummy data set with only one data point:
|
||||
|
||||
```julia
|
||||
x = rand(784)
|
||||
y = rand(10)
|
||||
data = [(x, y)]
|
||||
```
|
||||
|
||||
`Flux.train!` will call `loss(x, y)`, calculate gradients, update the weights and then move on to the next data point if there is one. We can train the model on the same data three times:
|
||||
|
||||
```julia
|
||||
data = [(x, y), (x, y), (x, y)]
|
||||
# Or equivalently
|
||||
data = Iterators.repeated((x, y), 3)
|
||||
```
|
||||
|
||||
It's common to load the `x`s and `y`s separately. In this case you can use `zip`:
|
||||
|
||||
```julia
|
||||
xs = [rand(784), rand(784), rand(784)]
|
||||
ys = [rand( 10), rand( 10), rand( 10)]
|
||||
data = zip(xs, ys)
|
||||
```
|
||||
|
||||
## Callbacks
|
||||
|
||||
`train!` takes an additional argument, `cb`, that's used for callbacks so that you can observe the training process. For example:
|
||||
|
|
Loading…
Reference in New Issue