diff --git a/src/Flux.jl b/src/Flux.jl
index d8db39e9..fb52e859 100644
--- a/src/Flux.jl
+++ b/src/Flux.jl
@@ -19,7 +19,7 @@ export Chain, Dense, Maxout, RNN, LSTM, GRU, Conv, CrossCor, ConvTranspose,
 include("optimise/Optimise.jl")
 using .Optimise
 using .Optimise: @epochs
-export SGD, Descent, ADAM, Momentum, Nesterov, RMSProp,
+export Descent, ADAM, Momentum, Nesterov, RMSProp,
   ADAGrad, AdaMax, ADADelta, AMSGrad, NADAM, ADAMW, RADAM,
   InvDecay, ExpDecay, WeightDecay, ClipValue, ClipNorm
 
diff --git a/src/optimise/Optimise.jl b/src/optimise/Optimise.jl
index 66be6dce..4ea5235e 100644
--- a/src/optimise/Optimise.jl
+++ b/src/optimise/Optimise.jl
@@ -2,10 +2,10 @@ module Optimise
 
 using LinearAlgebra
 
-export train!, update!,
-  SGD, Descent, ADAM, Momentum, Nesterov, RMSProp,
-  ADAGrad, AdaMax, ADADelta, AMSGrad, NADAM, ADAMW, RADAM,
-  InvDecay, ExpDecay, WeightDecay, stop, Optimiser, ClipValue, ClipNorm
+export train!, update!, stop, Optimiser,
+  Descent, ADAM, Momentum, Nesterov, RMSProp,
+  ADAGrad, AdaMax, ADADelta, AMSGrad, NADAM, ADAMW, RADAM,
+  InvDecay, ExpDecay, WeightDecay, ClipValue, ClipNorm
 
 include("optimisers.jl")
 include("train.jl")
diff --git a/src/utils.jl b/src/utils.jl
index c666caca..7842c961 100644
--- a/src/utils.jl
+++ b/src/utils.jl
@@ -24,7 +24,7 @@ glorot_uniform(dims...) = (rand(Float32, dims...) .- 0.5f0) .* sqrt(24.0f0 / sum
     glorot_normal(dims...)
 
 Return an `Array` of size `dims` containing random variables taken from a normal
-distribution with mean 0 and standard deviation `(2 / sum(dims))`.
+distribution with mean 0 and standard deviation `sqrt(2 / sum(dims))`.
 
 # Examples
 ```jldoctest; setup = :(using Random; Random.seed!(0))