diff --git a/src/optimise/train.jl b/src/optimise/train.jl
index 618ecf66..2a2ec5eb 100644
--- a/src/optimise/train.jl
+++ b/src/optimise/train.jl
@@ -4,10 +4,19 @@ using Flux.Tracker: back!
 tocb(f) = f
 tocb(fs::AbstractVector) = () -> foreach(call, fs)
 
+"""
+    train!(loss, data, opt; cb = () -> ())
+
+For each datapoint `d` in `data`, computes the gradient of `loss(d...)` through
+backpropagation, then calls the optimizer `opt` and the callback `cb`
+(i.e. `opt()` and `cb()`).
+
+Multiple callbacks can be passed to `cb` as an array.
+"""
 function train!(loss, data, opt; cb = () -> ())
   cb = tocb(cb)
-  @progress for x in data
-    l = loss(x...)
+  @progress for d in data
+    l = loss(d...)
     isinf(l.data[]) && error("Loss is Inf")
     isnan(l.data[]) && error("Loss is NaN")
     back!(l)
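
For context, a minimal usage sketch of the contract this docstring documents. The signature `train!(loss, data, opt; cb)` and the array-of-callbacks behavior come from the patch itself; the model, data, and optimizer setup (`Dense`, `Flux.mse`, `params`, and an `SGD` that returns a zero-argument update callable) are assumptions about the contemporaneous Tracker-era Flux API, not part of this change.

```julia
using Flux

# Placeholder model and loss; `loss` must return a tracked value so that
# `back!(l)` in `train!` can backpropagate through it.
m = Dense(10, 2)
loss(x, y) = Flux.mse(m(x), y)

# Each datapoint is a tuple, splatted into the loss as `loss(d...)`.
data = [(rand(10), rand(2)) for _ in 1:100]

# Assumed Tracker-era API: `SGD` returns a zero-argument callable, so
# `train!` can apply an update step by calling `opt()` after each backward pass.
opt = SGD(params(m), 0.1)

# A single callback, invoked as `cb()` after each datapoint:
evalcb() = @show loss(data[1]...)
Flux.train!(loss, data, opt; cb = evalcb)

# Multiple callbacks passed as an array, dispatched through the
# `tocb(fs::AbstractVector)` method, which calls each in turn:
Flux.train!(loss, data, opt; cb = [evalcb, () -> println("step done")])
```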