diff --git a/docs/src/training/training.md b/docs/src/training/training.md index 47bda1f5..b42db7c9 100644 --- a/docs/src/training/training.md +++ b/docs/src/training/training.md @@ -110,3 +110,30 @@ cb = function () accuracy() > 0.9 && Flux.stop() end ``` + +## Custom Training loops + +The `Flux.train!` function can be very convenient, especially for simple problems. +Its also very flexible with the use of callbacks. +But for some problems its much cleaner to write your own custom training loop. +An example follows that works similar to the default `Flux.train` but with no callbacks. +You don't need callbacks if you just code the calls to your functions directly into the loop. +E.g. in the places marked with comments. + +``` +function my_custom_train!(loss, ps, data, opt) + ps = Params(ps) + for d in data + gs = gradient(ps) do + training_loss = loss(d...) + # Insert what ever code you want here that needs Training loss, e.g. logging + return training_loss + end + # insert what ever code you want here that needs gradient + # E.g. logging with TensorBoardLogger.jl as histogram so you can see if it is becoming huge + update!(opt, ps, gs) + # Here you might like to check validation set accuracy, and break out to do early stopping + end +end +``` +You could simplify this further, for example by hard-coding in the loss function.