diff --git a/src/optimise/train.jl b/src/optimise/train.jl
index 6687f268..401a1c51 100644
--- a/src/optimise/train.jl
+++ b/src/optimise/train.jl
@@ -4,6 +4,16 @@ using Flux.Tracker: back!
 runall(f) = f
 runall(fs::AbstractVector) = () -> foreach(call, fs)
 
+# The AD generates fairly large backtraces that are unhelpful if you interrupt
+# while training; this just cleans that up.
+macro interrupts(ex)
+  :(try $(esc(ex))
+    catch e
+      e isa InterruptException || rethrow()
+      throw(e)
+    end)
+end
+
 """
     train!(loss, data, opt)
 
@@ -29,7 +39,7 @@ function train!(loss, data, opt; cb = () -> ())
     l = loss(d...)
     isinf(l) && error("Loss is Inf")
     isnan(l) && error("Loss is NaN")
-    back!(l)
+    @interrupts back!(l)
     opt()
     cb() == :stop && break
   end