Merge pull request #361 from dhairyagandhi96/with_stop
Add stop() to train loop when callback conditions are met
This commit is contained in:
commit
fac06751ea
@ -2,7 +2,7 @@ module Optimise
|
|||||||
|
|
||||||
export train!,
|
export train!,
|
||||||
SGD, ADAM, ADAMW, AdaMax, Momentum, Nesterov,
|
SGD, ADAM, ADAMW, AdaMax, Momentum, Nesterov,
|
||||||
RMSProp, ADAGrad, ADADelta, AMSGrad, NADAM
|
RMSProp, ADAGrad, ADADelta, AMSGrad, NADAM, stop, StopException
|
||||||
|
|
||||||
struct Param{T}
|
struct Param{T}
|
||||||
x::T
|
x::T
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
using Juno
|
using Juno
|
||||||
using Flux.Tracker: back!
|
using Flux.Tracker: back!
|
||||||
|
import Base.depwarn
|
||||||
|
|
||||||
runall(f) = f
|
runall(f) = f
|
||||||
runall(fs::AbstractVector) = () -> foreach(call, fs)
|
runall(fs::AbstractVector) = () -> foreach(call, fs)
|
||||||
@ -14,6 +15,25 @@ macro interrupts(ex)
|
|||||||
end)
|
end)
|
||||||
end
|
end
|
||||||
|
|
||||||
|
struct StopException <: Exception end
|
||||||
|
"""
|
||||||
|
stop()
|
||||||
|
|
||||||
|
Call `Flux.stop()` in a callback to indicate when a callback condition is met.
|
||||||
|
This would trigger the train loop to stop and exit.
|
||||||
|
|
||||||
|
```julia
|
||||||
|
# Example callback:
|
||||||
|
|
||||||
|
cb = function ()
|
||||||
|
accuracy() > 0.9 && Flux.stop()
|
||||||
|
end
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
function stop()
|
||||||
|
throw(StopException())
|
||||||
|
end
|
||||||
|
|
||||||
"""
|
"""
|
||||||
train!(loss, data, opt)
|
train!(loss, data, opt)
|
||||||
|
|
||||||
@ -36,10 +56,21 @@ function train!(loss, data, opt; cb = () -> ())
|
|||||||
cb = runall(cb)
|
cb = runall(cb)
|
||||||
opt = runall(opt)
|
opt = runall(opt)
|
||||||
@progress for d in data
|
@progress for d in data
|
||||||
|
try
|
||||||
l = loss(d...)
|
l = loss(d...)
|
||||||
@interrupts back!(l)
|
@interrupts back!(l)
|
||||||
opt()
|
opt()
|
||||||
cb() == :stop && break
|
if cb() == :stop
|
||||||
|
depwarn("Use of `:stop` is deprecated; use `Flux.stop()` instead", :stop)
|
||||||
|
break
|
||||||
|
end
|
||||||
|
catch ex
|
||||||
|
if ex isa StopException
|
||||||
|
break
|
||||||
|
else
|
||||||
|
rethrow(ex)
|
||||||
|
end
|
||||||
|
end
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
|
@ -23,7 +23,7 @@ end
|
|||||||
Flux.train!(() -> (sleep(0.1); i += 1; l),
|
Flux.train!(() -> (sleep(0.1); i += 1; l),
|
||||||
Iterators.repeated((), 100),
|
Iterators.repeated((), 100),
|
||||||
()->(),
|
()->(),
|
||||||
cb = Flux.throttle(() -> (i > 3 && :stop), 1))
|
cb = Flux.throttle(() -> (i > 3 && stop()), 1))
|
||||||
|
|
||||||
@test 3 < i < 50
|
@test 3 < i < 50
|
||||||
end
|
end
|
||||||
|
Loading…
Reference in New Issue
Block a user