diff --git a/src/optimise/train.jl b/src/optimise/train.jl
index 2a2ec5eb..0809e86b 100644
--- a/src/optimise/train.jl
+++ b/src/optimise/train.jl
@@ -1,8 +1,8 @@
 using Juno
 using Flux.Tracker: back!
 
-tocb(f) = f
-tocb(fs::AbstractVector) = () -> foreach(call, fs)
+runall(f) = f
+runall(fs::AbstractVector) = () -> foreach(call, fs)
 
 """
     train!(loss, data, opt; cb = () -> ())
@@ -11,10 +11,11 @@ For each datapoint `d` in `data` computes the gradient of `loss(d...)` through
 backpropagation and calls the optimizer `opt` and the callback `cb`
 (i.e. `opt()` and `cb()`).
 
-Multiple callbacks can be passed to `cb` as an array.
+Multiple optimisers and callbacks can be passed to `opt` and `cb` as arrays.
 """
 function train!(loss, data, opt; cb = () -> ())
-  cb = tocb(cb)
+  cb = runall(cb)
+  opt = runall(opt)
   @progress for d in data
     l = loss(d...)
     isinf(l.data[]) && error("Loss is Inf")
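
For context, a rough usage sketch of what this patch enables. The model, loss, and data below are hypothetical placeholders, and it assumes the Tracker-era API in which an optimiser is a zero-argument callable; the point is only that `opt` and `cb` may now each be an array, which `runall` folds into a single closure that calls every element.

using Flux

# Hypothetical toy setup, not part of this patch.
d1 = Dense(10, 5, relu)
d2 = Dense(5, 2)
m = Chain(d1, d2)
loss(x, y) = Flux.mse(m(x), y)             # loss over a single datapoint
data = [(rand(10), rand(2)) for _ = 1:100] # dummy dataset

# One optimiser per parameter group, passed together as an array;
# after each gradient step, `runall` invokes each of them in turn.
opts = [SGD(params(d1), 0.1), SGD(params(d2), 0.01)]

# Multiple callbacks, also passed as an array.
cbs = [() -> print("."), () -> nothing]

Flux.train!(loss, data, opts, cb = cbs)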