cleaner interrupts
This commit is contained in:
parent
5153cde847
commit
bfd6a4c0ec
@ -4,6 +4,16 @@ using Flux.Tracker: back!
|
|||||||
runall(f) = f
|
runall(f) = f
|
||||||
runall(fs::AbstractVector) = () -> foreach(call, fs)
|
runall(fs::AbstractVector) = () -> foreach(call, fs)
|
||||||
|
|
||||||
|
# The AD generates fairly large backtraces that are unhelpful if you interrupt
|
||||||
|
# while training; this just cleans that up.
|
||||||
|
macro interrupts(ex)
|
||||||
|
:(try $(esc(ex))
|
||||||
|
catch e
|
||||||
|
e isa InterruptException || rethrow()
|
||||||
|
throw(e)
|
||||||
|
end)
|
||||||
|
end
|
||||||
|
|
||||||
"""
|
"""
|
||||||
train!(loss, data, opt)
|
train!(loss, data, opt)
|
||||||
|
|
||||||
@ -29,7 +39,7 @@ function train!(loss, data, opt; cb = () -> ())
|
|||||||
l = loss(d...)
|
l = loss(d...)
|
||||||
isinf(l) && error("Loss is Inf")
|
isinf(l) && error("Loss is Inf")
|
||||||
isnan(l) && error("Loss is NaN")
|
isnan(l) && error("Loss is NaN")
|
||||||
back!(l)
|
@interrupts back!(l)
|
||||||
opt()
|
opt()
|
||||||
cb() == :stop && break
|
cb() == :stop && break
|
||||||
end
|
end
|
||||||
|
Loading…
Reference in New Issue
Block a user