Check for Inf and NaN within back!(::TrackedReal)

This is often checked for within user code, no reason to do that, let's
do it for them within `back!(::TrackedReal)`
This commit is contained in:
staticfloat@gmail.com 2018-05-07 15:30:44 -07:00
parent e92f840510
commit 9fdbe843ef
2 changed files with 9 additions and 3 deletions

View File

@ -37,8 +37,6 @@ function train!(loss, data, opt; cb = () -> ())
opt = runall(opt)
@progress for d in data
l = loss(d...)
isinf(l) && error("Loss is Inf")
isnan(l) && error("Loss is NaN")
@interrupts back!(l)
opt()
cb() == :stop && break

View File

@ -8,7 +8,15 @@ tracker(x::TrackedReal) = x.tracker
track(f::Call, x::Real) = TrackedReal(Tracked(f, x, zero(x)))
back!(x::TrackedReal) = back!(x, 1)
function back!(x::TrackedReal)
if isinf(x)
error("Loss is Inf")
end
if isnan(x)
error("Loss is NaN")
end
return back!(x, 1)
end
function Base.show(io::IO, x::TrackedReal)
show(io, data(x))