Check for Inf
and NaN
within back!(::TrackedReal)
This is often checked for within user code, no reason to do that, let's do it for them within `back!(::TrackedReal)`
This commit is contained in:
parent
e92f840510
commit
9fdbe843ef
@ -37,8 +37,6 @@ function train!(loss, data, opt; cb = () -> ())
|
|||||||
opt = runall(opt)
|
opt = runall(opt)
|
||||||
@progress for d in data
|
@progress for d in data
|
||||||
l = loss(d...)
|
l = loss(d...)
|
||||||
isinf(l) && error("Loss is Inf")
|
|
||||||
isnan(l) && error("Loss is NaN")
|
|
||||||
@interrupts back!(l)
|
@interrupts back!(l)
|
||||||
opt()
|
opt()
|
||||||
cb() == :stop && break
|
cb() == :stop && break
|
||||||
|
@ -8,7 +8,15 @@ tracker(x::TrackedReal) = x.tracker
|
|||||||
|
|
||||||
track(f::Call, x::Real) = TrackedReal(Tracked(f, x, zero(x)))
|
track(f::Call, x::Real) = TrackedReal(Tracked(f, x, zero(x)))
|
||||||
|
|
||||||
back!(x::TrackedReal) = back!(x, 1)
|
function back!(x::TrackedReal)
|
||||||
|
if isinf(x)
|
||||||
|
error("Loss is Inf")
|
||||||
|
end
|
||||||
|
if isnan(x)
|
||||||
|
error("Loss is NaN")
|
||||||
|
end
|
||||||
|
return back!(x, 1)
|
||||||
|
end
|
||||||
|
|
||||||
function Base.show(io::IO, x::TrackedReal)
|
function Base.show(io::IO, x::TrackedReal)
|
||||||
show(io, data(x))
|
show(io, data(x))
|
||||||
|
Loading…
Reference in New Issue
Block a user