diff --git a/src/optimise/optimisers.jl b/src/optimise/optimisers.jl index 2accc4bc..d750a848 100644 --- a/src/optimise/optimisers.jl +++ b/src/optimise/optimisers.jl @@ -149,7 +149,7 @@ function update!(o::ADAGrad, x, Δ) η = o.eta acc = get!(o.acc, x, fill(ϵ, size(x)))::typeof(x) @. acc += Δ^2 - @. Δ *= η / √(acc + ϵ) + @. Δ *= η / (√acc + ϵ) end """ @@ -169,7 +169,7 @@ function update!(o::ADADelta, x, Δ) ρ = o.rho acc, Δacc = get!(o.state, x, (zero(x), zero(x))) @. acc = ρ * acc + (1 - ρ) * Δ^2 - @. Δ *= √(Δacc + ϵ) / √(acc + ϵ) + @. Δ *= √Δacc/ (√acc + ϵ) @. Δacc = ρ * Δacc + (1 - ρ) * Δ^2 return Δ end @@ -194,7 +194,7 @@ function update!(o::AMSGrad, x, Δ) @. mt = β[1] * mt + (1 - β[1]) * Δ @. vt = β[2] * vt + (1 - β[2]) * Δ ^ 2 @. v̂t = max.(v̂t, vt) - @. Δ = η * mt / √v̂t + @. Δ = η * mt / (√v̂t + ϵ) end """ @@ -217,7 +217,7 @@ function update!(o::NADAM, x, Δ) mt, vt = get!(o.state, x, (zero(x), zero(x))) @. mt = β[1] * mt + (1 - β[1]) * Δ @. vt = β[2] * vt + (1 - β[2]) * Δ^2 - @. Δ = (β[1] * mt / (1 - β[1] * β1p) + (1 - β[1]) * Δ / (1 - β1p)) / √(vt * β[2] / (1 - β2p) + ϵ) * η + @. Δ = (β[1] * mt / (1 - β[1] * β1p) + (1 - β[1]) * Δ / (1 - β1p)) / (√(vt * β[2] / (1 - β2p)) + ϵ) * η o.state[x] = (mt, vt, (β1p * β[1], β2p * β[2])) return Δ end