Merge pull request #361 from dhairyagandhi96/with_stop

Add stop() to train loop when callback conditions are met
This commit is contained in:
Mike J Innes 2018-08-28 10:56:15 +01:00 committed by GitHub
commit fac06751ea
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 37 additions and 6 deletions

View File

@ -2,7 +2,7 @@ module Optimise
export train!, export train!,
SGD, ADAM, ADAMW, AdaMax, Momentum, Nesterov, SGD, ADAM, ADAMW, AdaMax, Momentum, Nesterov,
RMSProp, ADAGrad, ADADelta, AMSGrad, NADAM RMSProp, ADAGrad, ADADelta, AMSGrad, NADAM, stop, StopException
struct Param{T} struct Param{T}
x::T x::T

View File

@ -1,5 +1,6 @@
using Juno using Juno
using Flux.Tracker: back! using Flux.Tracker: back!
import Base.depwarn
runall(f) = f runall(f) = f
runall(fs::AbstractVector) = () -> foreach(call, fs) runall(fs::AbstractVector) = () -> foreach(call, fs)
@ -14,6 +15,25 @@ macro interrupts(ex)
end) end)
end end
struct StopException <: Exception end
"""
stop()
Call `Flux.stop()` in a callback to indicate when a callback condition is met.
This would trigger the train loop to stop and exit.
```julia
# Example callback:
cb = function ()
accuracy() > 0.9 && Flux.stop()
end
```
"""
function stop()
throw(StopException())
end
""" """
train!(loss, data, opt) train!(loss, data, opt)
@ -36,10 +56,21 @@ function train!(loss, data, opt; cb = () -> ())
cb = runall(cb) cb = runall(cb)
opt = runall(opt) opt = runall(opt)
@progress for d in data @progress for d in data
l = loss(d...) try
@interrupts back!(l) l = loss(d...)
opt() @interrupts back!(l)
cb() == :stop && break opt()
if cb() == :stop
depwarn("Use of `:stop` is deprecated; use `Flux.stop()` instead", :stop)
break
end
catch ex
if ex isa StopException
break
else
rethrow(ex)
end
end
end end
end end

View File

@ -23,7 +23,7 @@ end
Flux.train!(() -> (sleep(0.1); i += 1; l), Flux.train!(() -> (sleep(0.1); i += 1; l),
Iterators.repeated((), 100), Iterators.repeated((), 100),
()->(), ()->(),
cb = Flux.throttle(() -> (i > 3 && :stop), 1)) cb = Flux.throttle(() -> (i > 3 && stop()), 1))
@test 3 < i < 50 @test 3 < i < 50
end end