diff --git a/src/optimise/optimisers.jl b/src/optimise/optimisers.jl index aa2db1c5..2319cfdb 100644 --- a/src/optimise/optimisers.jl +++ b/src/optimise/optimisers.jl @@ -213,8 +213,7 @@ NADAM(η = 0.001, β = (0.9, 0.999)) = NADAM(η, β, IdDict()) function apply!(o::NADAM, x, Δ) η, β = o.eta, o.beta - β1p, β2p = o.beta - mt, vt = get!(o.state, x, (zero(x), zero(x))) + mt, vt, (β1p, β2p) = get!(o.state, x, (zero(x), zero(x), o.beta)) @. mt = β[1] * mt + (1 - β[1]) * Δ @. vt = β[2] * vt + (1 - β[2]) * Δ^2 @. Δ = (β[1] * mt / (1 - β[1] * β1p) + (1 - β[1]) * Δ / (1 - β1p)) / (√(vt * β[2] / (1 - β2p)) + ϵ) * η