diff --git a/src/Flux.jl b/src/Flux.jl index cd407705..525b33c4 100644 --- a/src/Flux.jl +++ b/src/Flux.jl @@ -21,7 +21,7 @@ include("optimise/Optimise.jl") using .Optimise using .Optimise: @epochs export SGD, ADAM, ADAMW, AdaMax, Momentum, Nesterov, - RMSProp, ADAGrad, ADADelta, AMSGrad, NADAM, stop, StopException + RMSProp, ADAGrad, ADADelta, AMSGrad, NADAM, stop include("utils.jl") include("onehot.jl") diff --git a/src/optimise/Optimise.jl b/src/optimise/Optimise.jl index c4828c9e..ee7723bc 100644 --- a/src/optimise/Optimise.jl +++ b/src/optimise/Optimise.jl @@ -2,7 +2,7 @@ module Optimise export train!, SGD, ADAM, ADAMW, AdaMax, Momentum, Nesterov, - RMSProp, ADAGrad, ADADelta, AMSGrad, NADAM, stop, StopException + RMSProp, ADAGrad, ADADelta, AMSGrad, NADAM struct Param{T} x::T diff --git a/src/optimise/train.jl b/src/optimise/train.jl index c84a8191..3ec3eb18 100644 --- a/src/optimise/train.jl +++ b/src/optimise/train.jl @@ -55,14 +55,13 @@ function train!(loss, data, opt; cb = () -> ()) cb = runall(cb) opt = runall(opt) @progress for d in data - l = loss(d...) - @interrupts back!(l) - opt() try - cb() + l = loss(d...) + @interrupts back!(l) + opt() + cb() == :stop && break catch ex if ex isa StopException - @info "Stop condition met" break else rethrow(ex)