diff --git a/src/optimise/train.jl b/src/optimise/train.jl
index 54b7f53a..79ebcc06 100644
--- a/src/optimise/train.jl
+++ b/src/optimise/train.jl
@@ -61,7 +61,7 @@ end
 For each datapoint `d` in `data` computes the gradient of `loss(d...)` through
 backpropagation and calls the optimizer `opt`.
 
-In case datapoints `d` are of array type, assumes no splatting is needed
+In case datapoints `d` are of numeric array type, assumes no splatting is needed
 and computes the gradient of `loss(d)`.
 
 Takes a callback as keyword argument `cb`. For example, this will print "training"
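
For context, here is a minimal sketch of the two call patterns this docstring change distinguishes, assuming the classic `train!(loss, params, data, opt)` API that this file documents; the `model`, the loss functions, and the sample data below are hypothetical, not part of the patch:

```julia
using Flux

# Hypothetical model and optimizer, purely for illustration.
model = Dense(10, 1)
opt = Descent(0.1)
ps = Flux.params(model)

# Tuple datapoints: each d = (x, y) is splatted, so train! computes
# the gradient of loss(x, y).
supervised_loss(x, y) = Flux.mse(model(x), y)
data_tuples = [(rand(Float32, 10), rand(Float32, 1)) for _ in 1:4]
Flux.train!(supervised_loss, ps, data_tuples, opt)

# Numeric-array datapoints: per the docstring above, each d is passed
# whole with no splatting, so train! computes the gradient of loss(d).
array_loss(x) = sum(abs2, model(x))
data_arrays = [rand(Float32, 10) for _ in 1:4]
Flux.train!(array_loss, ps, data_arrays, opt)
```

The wording change ("array type" to "numeric array type") narrows when the no-splat path applies: only datapoints that are arrays of numbers, like `data_arrays` above, are passed whole to the loss.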