diff --git a/src/optimise/train.jl b/src/optimise/train.jl index 401a1c51..8ad8573e 100644 --- a/src/optimise/train.jl +++ b/src/optimise/train.jl @@ -37,8 +37,6 @@ function train!(loss, data, opt; cb = () -> ()) opt = runall(opt) @progress for d in data l = loss(d...) - isinf(l) && error("Loss is Inf") - isnan(l) && error("Loss is NaN") @interrupts back!(l) opt() cb() == :stop && break diff --git a/src/tracker/scalar.jl b/src/tracker/scalar.jl index 8d0aa29e..5b6cfb57 100644 --- a/src/tracker/scalar.jl +++ b/src/tracker/scalar.jl @@ -8,7 +8,15 @@ tracker(x::TrackedReal) = x.tracker track(f::Call, x::Real) = TrackedReal(Tracked(f, x, zero(x))) -back!(x::TrackedReal) = back!(x, 1) +function back!(x::TrackedReal) + if isinf(x) + error("Loss is Inf") + end + if isnan(x) + error("Loss is NaN") + end + return back!(x, 1) +end function Base.show(io::IO, x::TrackedReal) show(io, data(x))