Merge pull request #361 from dhairyagandhi96/with_stop
Add stop() to train loop when callback conditions are met
This commit is contained in:
commit
fac06751ea
|
@ -2,7 +2,7 @@ module Optimise
|
|||
|
||||
export train!,
|
||||
SGD, ADAM, ADAMW, AdaMax, Momentum, Nesterov,
|
||||
RMSProp, ADAGrad, ADADelta, AMSGrad, NADAM
|
||||
RMSProp, ADAGrad, ADADelta, AMSGrad, NADAM, stop, StopException
|
||||
|
||||
struct Param{T}
|
||||
x::T
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
using Juno
|
||||
using Flux.Tracker: back!
|
||||
import Base.depwarn
|
||||
|
||||
runall(f) = f
|
||||
runall(fs::AbstractVector) = () -> foreach(call, fs)
|
||||
|
@ -14,6 +15,25 @@ macro interrupts(ex)
|
|||
end)
|
||||
end
|
||||
|
||||
struct StopException <: Exception end
|
||||
"""
|
||||
stop()
|
||||
|
||||
Call `Flux.stop()` in a callback to indicate when a callback condition is met.
|
||||
This would trigger the train loop to stop and exit.
|
||||
|
||||
```julia
|
||||
# Example callback:
|
||||
|
||||
cb = function ()
|
||||
accuracy() > 0.9 && Flux.stop()
|
||||
end
|
||||
```
|
||||
"""
|
||||
function stop()
|
||||
throw(StopException())
|
||||
end
|
||||
|
||||
"""
|
||||
train!(loss, data, opt)
|
||||
|
||||
|
@ -36,10 +56,21 @@ function train!(loss, data, opt; cb = () -> ())
|
|||
cb = runall(cb)
|
||||
opt = runall(opt)
|
||||
@progress for d in data
|
||||
l = loss(d...)
|
||||
@interrupts back!(l)
|
||||
opt()
|
||||
cb() == :stop && break
|
||||
try
|
||||
l = loss(d...)
|
||||
@interrupts back!(l)
|
||||
opt()
|
||||
if cb() == :stop
|
||||
depwarn("Use of `:stop` is deprecated; use `Flux.stop()` instead", :stop)
|
||||
break
|
||||
end
|
||||
catch ex
|
||||
if ex isa StopException
|
||||
break
|
||||
else
|
||||
rethrow(ex)
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
|
|
|
@ -23,7 +23,7 @@ end
|
|||
Flux.train!(() -> (sleep(0.1); i += 1; l),
|
||||
Iterators.repeated((), 100),
|
||||
()->(),
|
||||
cb = Flux.throttle(() -> (i > 3 && :stop), 1))
|
||||
cb = Flux.throttle(() -> (i > 3 && stop()), 1))
|
||||
|
||||
@test 3 < i < 50
|
||||
end
|
||||
|
|
Loading…
Reference in New Issue