Check for Inf
and NaN
within back!(::TrackedReal)
This is often checked for within user code, no reason to do that, let's do it for them within `back!(::TrackedReal)`
This commit is contained in:
parent
e92f840510
commit
9fdbe843ef
@ -37,8 +37,6 @@ function train!(loss, data, opt; cb = () -> ())
|
||||
opt = runall(opt)
|
||||
@progress for d in data
|
||||
l = loss(d...)
|
||||
isinf(l) && error("Loss is Inf")
|
||||
isnan(l) && error("Loss is NaN")
|
||||
@interrupts back!(l)
|
||||
opt()
|
||||
cb() == :stop && break
|
||||
|
@ -8,7 +8,15 @@ tracker(x::TrackedReal) = x.tracker
|
||||
|
||||
track(f::Call, x::Real) = TrackedReal(Tracked(f, x, zero(x)))
|
||||
|
||||
back!(x::TrackedReal) = back!(x, 1)
|
||||
function back!(x::TrackedReal)
|
||||
if isinf(x)
|
||||
error("Loss is Inf")
|
||||
end
|
||||
if isnan(x)
|
||||
error("Loss is NaN")
|
||||
end
|
||||
return back!(x, 1)
|
||||
end
|
||||
|
||||
function Base.show(io::IO, x::TrackedReal)
|
||||
show(io, data(x))
|
||||
|
Loading…
Reference in New Issue
Block a user