diff --git a/src/optimise/train.jl b/src/optimise/train.jl index 2cc20268..1928a80d 100644 --- a/src/optimise/train.jl +++ b/src/optimise/train.jl @@ -34,18 +34,22 @@ The callback can return `:stop` to interrupt the training loop. Multiple optimisers and callbacks can be passed to `opt` and `cb` as arrays. """ function train!(loss, data, opt; cb = () -> ()) - cb = try - runall(cb) - catch ex - if ex isa StopException || rethrow(ex) - @info "Stop Condition Met" - return :stop + cb = runall(cb) opt = runall(opt) @progress for d in data l = loss(d...) @interrupts back!(l) opt() - cb() == :stop && break + try + cb() + catch ex + if ex isa StopException + @info "Stop condition met" + break + else + rethrow(ex) + end + end end