diff --git a/src/optimise/train.jl b/src/optimise/train.jl index 6317b3ec..07577e94 100644 --- a/src/optimise/train.jl +++ b/src/optimise/train.jl @@ -67,6 +67,7 @@ function train!(loss, ps, data, opt; cb = () -> ()) loss(d...) end update!(opt, ps, gs) + cb() catch ex if ex isa StopException break