diff --git a/src/Flux.jl b/src/Flux.jl index e406b5c6..c01dbd4e 100644 --- a/src/Flux.jl +++ b/src/Flux.jl @@ -7,7 +7,7 @@ using MacroTools: @forward export Chain, Dense, RNN, LSTM, GRU, Conv, Dropout, LayerNorm, BatchNorm, - params, mapleaves, cpu, gpu + params, mapleaves, cpu, gpu, stop, StopException @reexport using NNlib using NNlib: @fix @@ -24,7 +24,6 @@ export SGD, ADAM, ADAMW, AdaMax, Momentum, Nesterov, RMSProp, ADAGrad, ADADelta, AMSGrad, NADAM include("utils.jl") -export stop, StopException include("onehot.jl") include("treelike.jl") @@ -39,3 +38,4 @@ include("data/Data.jl") @init @require CuArrays="3a865a2d-5b23-5a0f-bc46-62713ec82fae" include("cuda/cuda.jl") end # module + diff --git a/src/optimise/train.jl b/src/optimise/train.jl index 3f26fdbd..4b1e205b 100644 --- a/src/optimise/train.jl +++ b/src/optimise/train.jl @@ -1,6 +1,6 @@ using Juno using Flux.Tracker: back! -# include("../utils.jl") +import Flux: stop, StopException runall(f) = f runall(fs::AbstractVector) = () -> foreach(call, fs)