pick beta from state in NADAM

This commit is contained in:
Dhairya Gandhi 2019-06-16 19:06:59 +05:30
parent b47238eb74
commit 67f18663d9

View File

@ -214,7 +214,7 @@ NADAM(η = 0.001, β = (0.9, 0.999)) = NADAM(η, β, IdDict())
function apply!(o::NADAM, x, Δ)
η, β = o.eta, o.beta
β1p, β2p = o.beta
mt, vt = get!(o.state, x, (zero(x), zero(x)))
mt, vt, (β1p, β2p) = get!(o.state, x, (zero(x), zero(x), o.beta))
@. mt = β[1] * mt + (1 - β[1]) * Δ
@. vt = β[2] * vt + (1 - β[2]) * Δ^2
@. Δ = (β[1] * mt / (1 - β[1] * β1p) + (1 - β[1]) * Δ / (1 - β1p)) / ((vt * β[2] / (1 - β2p)) + ϵ) * η