From 9fdbe843eff2eb091ccc25c4dfcb74ee6f258eb8 Mon Sep 17 00:00:00 2001 From: "staticfloat@gmail.com" Date: Mon, 7 May 2018 15:30:44 -0700 Subject: [PATCH] Check for `Inf` and `NaN` within `back!(::TrackedReal)` This is often checked for within user code, no reason to do that, let's do it for them within `back!(::TrackedReal)` --- src/optimise/train.jl | 2 -- src/tracker/scalar.jl | 10 +++++++++- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/src/optimise/train.jl b/src/optimise/train.jl index 401a1c51..8ad8573e 100644 --- a/src/optimise/train.jl +++ b/src/optimise/train.jl @@ -37,8 +37,6 @@ function train!(loss, data, opt; cb = () -> ()) opt = runall(opt) @progress for d in data l = loss(d...) - isinf(l) && error("Loss is Inf") - isnan(l) && error("Loss is NaN") @interrupts back!(l) opt() cb() == :stop && break diff --git a/src/tracker/scalar.jl b/src/tracker/scalar.jl index 632046cd..17c35513 100644 --- a/src/tracker/scalar.jl +++ b/src/tracker/scalar.jl @@ -8,7 +8,15 @@ tracker(x::TrackedReal) = x.tracker track(f::Call, x::Real) = TrackedReal(Tracked(f, x, zero(x))) -back!(x::TrackedReal) = back!(x, 1) +function back!(x::TrackedReal) + if isinf(x) + error("Loss is Inf") + end + if isnan(x) + error("Loss is NaN") + end + return back!(x, 1) +end function Base.show(io::IO, x::TrackedReal) show(io, data(x))