Add custom training loops to docs
This commit is contained in:
parent
d1edd9b16d
commit
7797e31b44
|
@ -110,3 +110,30 @@ cb = function ()
|
||||||
accuracy() > 0.9 && Flux.stop()
|
accuracy() > 0.9 && Flux.stop()
|
||||||
end
|
end
|
||||||
```
|
```
|
||||||
|
|
||||||
|
## Custom Training loops
|
||||||
|
|
||||||
|
The `Flux.train!` function can be very convenient, especially for simple problems.
|
||||||
|
Its also very flexible with the use of callbacks.
|
||||||
|
But for some problems its much cleaner to write your own custom training loop.
|
||||||
|
An example follows that works similar to the default `Flux.train` but with no callbacks.
|
||||||
|
You don't need callbacks if you just code the calls to your functions directly into the loop.
|
||||||
|
E.g. in the places marked with comments.
|
||||||
|
|
||||||
|
```
|
||||||
|
function my_custom_train!(loss, ps, data, opt)
|
||||||
|
ps = Params(ps)
|
||||||
|
for d in data
|
||||||
|
gs = gradient(ps) do
|
||||||
|
training_loss = loss(d...)
|
||||||
|
# Insert what ever code you want here that needs Training loss, e.g. logging
|
||||||
|
return training_loss
|
||||||
|
end
|
||||||
|
# insert what ever code you want here that needs gradient
|
||||||
|
# E.g. logging with TensorBoardLogger.jl as histogram so you can see if it is becoming huge
|
||||||
|
update!(opt, ps, gs)
|
||||||
|
# Here you might like to check validation set accuracy, and break out to do early stopping
|
||||||
|
end
|
||||||
|
end
|
||||||
|
```
|
||||||
|
You could simplify this further, for example by hard-coding in the loss function.
|
||||||
|
|
Loading…
Reference in New Issue