cleaner interrupts

This commit is contained in:
Mike Innes 2018-03-05 23:05:45 +00:00
parent 5153cde847
commit bfd6a4c0ec

View File

@ -4,6 +4,16 @@ using Flux.Tracker: back!
runall(f) = f runall(f) = f
runall(fs::AbstractVector) = () -> foreach(call, fs) runall(fs::AbstractVector) = () -> foreach(call, fs)
# The AD generates fairly large backtraces that are unhelpful if you interrupt
# while training; this just cleans that up.
macro interrupts(ex)
:(try $(esc(ex))
catch e
e isa InterruptException || rethrow()
throw(e)
end)
end
""" """
train!(loss, data, opt) train!(loss, data, opt)
@ -29,7 +39,7 @@ function train!(loss, data, opt; cb = () -> ())
l = loss(d...) l = loss(d...)
isinf(l) && error("Loss is Inf") isinf(l) && error("Loss is Inf")
isnan(l) && error("Loss is NaN") isnan(l) && error("Loss is NaN")
back!(l) @interrupts back!(l)
opt() opt()
cb() == :stop && break cb() == :stop && break
end end