From 394b4167cedb2a2175721ae49a805734e5c05f74 Mon Sep 17 00:00:00 2001 From: Dhairya Gandhi Date: Mon, 20 Aug 2018 13:43:08 +0530 Subject: [PATCH] moving stop to Optimise --- src/Flux.jl | 4 ++-- src/optimise/Optimise.jl | 2 +- src/optimise/train.jl | 8 +++++++- src/utils.jl | 10 ++++++---- 4 files changed, 16 insertions(+), 8 deletions(-) diff --git a/src/Flux.jl b/src/Flux.jl index c01dbd4e..cd407705 100644 --- a/src/Flux.jl +++ b/src/Flux.jl @@ -7,7 +7,7 @@ using MacroTools: @forward export Chain, Dense, RNN, LSTM, GRU, Conv, Dropout, LayerNorm, BatchNorm, - params, mapleaves, cpu, gpu, stop, StopException + params, mapleaves, cpu, gpu @reexport using NNlib using NNlib: @fix @@ -21,7 +21,7 @@ include("optimise/Optimise.jl") using .Optimise using .Optimise: @epochs export SGD, ADAM, ADAMW, AdaMax, Momentum, Nesterov, - RMSProp, ADAGrad, ADADelta, AMSGrad, NADAM + RMSProp, ADAGrad, ADADelta, AMSGrad, NADAM, stop, StopException include("utils.jl") include("onehot.jl") diff --git a/src/optimise/Optimise.jl b/src/optimise/Optimise.jl index ee7723bc..c4828c9e 100644 --- a/src/optimise/Optimise.jl +++ b/src/optimise/Optimise.jl @@ -2,7 +2,7 @@ module Optimise export train!, SGD, ADAM, ADAMW, AdaMax, Momentum, Nesterov, - RMSProp, ADAGrad, ADADelta, AMSGrad, NADAM + RMSProp, ADAGrad, ADADelta, AMSGrad, NADAM, stop, StopException struct Param{T} x::T diff --git a/src/optimise/train.jl b/src/optimise/train.jl index 0a06492c..341e6b43 100644 --- a/src/optimise/train.jl +++ b/src/optimise/train.jl @@ -1,6 +1,5 @@ using Juno using Flux.Tracker: back! -using Flux: stop, StopException runall(f) = f runall(fs::AbstractVector) = () -> foreach(call, fs) @@ -15,6 +14,13 @@ macro interrupts(ex) end) end +struct StopException <: Exception + x::Symbol +end + +function stop(x) + throw(StopException( + """ train!(loss, data, opt) diff --git a/src/utils.jl b/src/utils.jl index c746f391..321e0d94 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -154,8 +154,10 @@ function jacobian(m,x) J' end -struct StopException <: Exception end +# struct StopException <: Exception +# x::Symbol +# end -function stop() - throw(StopException()) -end \ No newline at end of file +# function stop(x) +# throw(StopException(x)) +# end \ No newline at end of file