diff --git a/src/optimise/optimisers.jl b/src/optimise/optimisers.jl index bc1f9805..31d47c32 100644 --- a/src/optimise/optimisers.jl +++ b/src/optimise/optimisers.jl @@ -81,7 +81,7 @@ function nadam(p::Param; η::Real = 0.001, β1::Real = 0.9, β2::Real = 0.999, function () @. mt = β1 * mt + (1 - β1) * p.Δ @. vt = β2 * vt + (1 - β2) * p.Δ^2 - @. p.Δ = (β1 * mt + (1 - β1) * p.Δ) / (1 - β1p) / (√(vt / (1 - β2p)) + ϵ) * η + @. p.Δ = (β1 * mt / (1 - β1 * β1p) + (1 - β1) * p.Δ / (1 - β1p)) / √(vt * β2 / (1 - β2p) + ϵ) * η β1p *= β1 β2p *= β2 end