From ed044e2df78a67c8ce647ac8e09eea831d9462ae Mon Sep 17 00:00:00 2001 From: Dhairya Gandhi Date: Tue, 21 Aug 2018 23:22:20 +0530 Subject: [PATCH] changes as requested --- src/Flux.jl | 2 +- src/optimise/Optimise.jl | 2 +- src/optimise/train.jl | 9 ++++----- 3 files changed, 6 insertions(+), 7 deletions(-) diff --git a/src/Flux.jl b/src/Flux.jl index cd407705..525b33c4 100644 --- a/src/Flux.jl +++ b/src/Flux.jl @@ -21,7 +21,7 @@ include("optimise/Optimise.jl") using .Optimise using .Optimise: @epochs export SGD, ADAM, ADAMW, AdaMax, Momentum, Nesterov, - RMSProp, ADAGrad, ADADelta, AMSGrad, NADAM, stop, StopException + RMSProp, ADAGrad, ADADelta, AMSGrad, NADAM, stop include("utils.jl") include("onehot.jl") diff --git a/src/optimise/Optimise.jl b/src/optimise/Optimise.jl index c4828c9e..ee7723bc 100644 --- a/src/optimise/Optimise.jl +++ b/src/optimise/Optimise.jl @@ -2,7 +2,7 @@ module Optimise export train!, SGD, ADAM, ADAMW, AdaMax, Momentum, Nesterov, - RMSProp, ADAGrad, ADADelta, AMSGrad, NADAM, stop, StopException + RMSProp, ADAGrad, ADADelta, AMSGrad, NADAM struct Param{T} x::T diff --git a/src/optimise/train.jl b/src/optimise/train.jl index c84a8191..3ec3eb18 100644 --- a/src/optimise/train.jl +++ b/src/optimise/train.jl @@ -55,14 +55,13 @@ function train!(loss, data, opt; cb = () -> ()) cb = runall(cb) opt = runall(opt) @progress for d in data - l = loss(d...) - @interrupts back!(l) - opt() try - cb() + l = loss(d...) + @interrupts back!(l) + opt() + cb() == :stop && break catch ex if ex isa StopException - @info "Stop condition met" break else rethrow(ex)