diff --git a/src/optimise/optimisers.jl b/src/optimise/optimisers.jl index 569e69aa..29b058ba 100644 --- a/src/optimise/optimisers.jl +++ b/src/optimise/optimisers.jl @@ -64,11 +64,11 @@ end function adamax(p::Param; η::Real = 0.002, β1::Real = 0.9, β2::Real = 0.999, ϵ::Real = 1e-8) mt = zeros(p.x) - ut = zero(p.x) + ut = zeros(p.x) β1p = β1 function () @. mt = β1 * mt + (1 - β1) * p.Δ - ut = max(β2 * ut, norm(p.Δ, Inf)) + @. ut = max(β2 * ut, abs(p.Δ)) @. p.Δ = (η/(1 - β1p)) * mt/(ut + ϵ) β1p *= β1 end