Merge pull request #994 from FluxML/ox/doccustomtraining
Add custom training loops to docs
commit ddc2c20e68
@@ -110,3 +110,30 @@ cb = function ()
  accuracy() > 0.9 && Flux.stop()
end
```
## Custom Training Loops
The `Flux.train!` function can be very convenient, especially for simple problems.
It is also very flexible through its use of callbacks.
But for some problems it is much cleaner to write your own custom training loop.
The example that follows works much like the default `Flux.train!`, but with no callbacks.
You don't need callbacks if you write the calls to your functions directly into the loop,
e.g. in the places marked with comments below.
```julia
using Flux
using Flux.Optimise: update!
using Zygote: Params, gradient

function my_custom_train!(loss, ps, data, opt)
  ps = Params(ps)
  for d in data
    gs = gradient(ps) do
      training_loss = loss(d...)
      # Insert whatever code you want here that needs the training loss, e.g. logging.
      return training_loss
    end
    # Insert whatever code you want here that needs the gradients, e.g. logging them
    # with TensorBoardLogger.jl as a histogram so you can see if they are becoming huge.
    update!(opt, ps, gs)
    # Here you might like to check the validation set accuracy, and break out to do early stopping.
  end
end
```
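As a usage sketch, here is one way to call it; the `Dense` model, `mse` loss, and random toy data below are illustrative assumptions, not part of the original example.

```julia
using Flux

# Hypothetical toy setup, just to show how `my_custom_train!` is called.
model = Dense(10, 1)
loss(x, y) = Flux.mse(model(x), y)
data = [(rand(Float32, 10, 8), rand(Float32, 1, 8)) for _ in 1:10]
opt = Descent(0.1)

my_custom_train!(loss, Flux.params(model), data, opt)
```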
You could simplify this further, for example by hard-coding the loss function, as sketched below.
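A minimal sketch of that simplification, assuming a `model` and `(x, y)` data pairs are in scope; the name `my_simpler_train!` and the `mse` loss are hypothetical choices, not from the original.

```julia
using Flux
using Flux.Optimise: update!
using Zygote: Params, gradient

# Sketch: the same loop as above, but with the loss hard-coded rather than passed in.
function my_simpler_train!(ps, data, opt)
  ps = Params(ps)
  for (x, y) in data
    gs = gradient(ps) do
      Flux.mse(model(x), y)  # hard-coded loss; assumes `model` is in scope
    end
    update!(opt, ps, gs)
  end
end
```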