Merge pull request #361 from dhairyagandhi96/with_stop

Add stop() to train loop when callback conditions are met
This commit is contained in:
Mike J Innes 2018-08-28 10:56:15 +01:00 committed by GitHub
commit fac06751ea
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 37 additions and 6 deletions

View File

@ -2,7 +2,7 @@ module Optimise
export train!,
SGD, ADAM, ADAMW, AdaMax, Momentum, Nesterov,
RMSProp, ADAGrad, ADADelta, AMSGrad, NADAM
RMSProp, ADAGrad, ADADelta, AMSGrad, NADAM, stop, StopException
struct Param{T}
x::T

View File

@ -1,5 +1,6 @@
using Juno
using Flux.Tracker: back!
import Base.depwarn
runall(f) = f
runall(fs::AbstractVector) = () -> foreach(call, fs)
@ -14,6 +15,25 @@ macro interrupts(ex)
end)
end
struct StopException <: Exception end
"""
stop()
Call `Flux.stop()` in a callback to indicate when a callback condition is met.
This would trigger the train loop to stop and exit.
```julia
# Example callback:
cb = function ()
accuracy() > 0.9 && Flux.stop()
end
```
"""
function stop()
throw(StopException())
end
"""
train!(loss, data, opt)
@ -36,10 +56,21 @@ function train!(loss, data, opt; cb = () -> ())
cb = runall(cb)
opt = runall(opt)
@progress for d in data
l = loss(d...)
@interrupts back!(l)
opt()
cb() == :stop && break
try
l = loss(d...)
@interrupts back!(l)
opt()
if cb() == :stop
depwarn("Use of `:stop` is deprecated; use `Flux.stop()` instead", :stop)
break
end
catch ex
if ex isa StopException
break
else
rethrow(ex)
end
end
end
end

View File

@ -23,7 +23,7 @@ end
Flux.train!(() -> (sleep(0.1); i += 1; l),
Iterators.repeated((), 100),
()->(),
cb = Flux.throttle(() -> (i > 3 && :stop), 1))
cb = Flux.throttle(() -> (i > 3 && stop()), 1))
@test 3 < i < 50
end