Merge pull request #448 from JobJob/adam-match-paper

Match paper for Adam implementation and make epsilon use more consistent
This commit is contained in:
Mike J Innes 2018-11-05 12:57:30 +00:00 committed by GitHub
commit d071014fae
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@@ -149,7 +149,7 @@ function update!(o::ADAGrad, x, Δ)
η = o.eta η = o.eta
acc = get!(o.acc, x, fill(ϵ, size(x)))::typeof(x) acc = get!(o.acc, x, fill(ϵ, size(x)))::typeof(x)
@. acc += Δ^2 @. acc += Δ^2
@. Δ *= η / (acc + ϵ) @. Δ *= η / (acc + ϵ)
end end
""" """
@@ -169,7 +169,7 @@ function update!(o::ADADelta, x, Δ)
ρ = o.rho ρ = o.rho
acc, Δacc = get!(o.state, x, (zero(x), zero(x))) acc, Δacc = get!(o.state, x, (zero(x), zero(x)))
@. acc = ρ * acc + (1 - ρ) * Δ^2 @. acc = ρ * acc + (1 - ρ) * Δ^2
@. Δ *= (Δacc + ϵ) / (acc + ϵ) @. Δ *= Δacc/ (acc + ϵ)
@. Δacc = ρ * Δacc + (1 - ρ) * Δ^2 @. Δacc = ρ * Δacc + (1 - ρ) * Δ^2
return Δ return Δ
end end
@@ -194,7 +194,7 @@ function update!(o::AMSGrad, x, Δ)
@. mt = β[1] * mt + (1 - β[1]) * Δ @. mt = β[1] * mt + (1 - β[1]) * Δ
@. vt = β[2] * vt + (1 - β[2]) * Δ ^ 2 @. vt = β[2] * vt + (1 - β[2]) * Δ ^ 2
@. v̂t = max.(v̂t, vt) @. v̂t = max.(v̂t, vt)
@. Δ = η * mt / v̂t @. Δ = η * mt / (v̂t + ϵ)
end end
""" """
@@ -217,7 +217,7 @@ function update!(o::NADAM, x, Δ)
mt, vt = get!(o.state, x, (zero(x), zero(x))) mt, vt = get!(o.state, x, (zero(x), zero(x)))
@. mt = β[1] * mt + (1 - β[1]) * Δ @. mt = β[1] * mt + (1 - β[1]) * Δ
@. vt = β[2] * vt + (1 - β[2]) * Δ^2 @. vt = β[2] * vt + (1 - β[2]) * Δ^2
@. Δ = (β[1] * mt / (1 - β[1] * β1p) + (1 - β[1]) * Δ / (1 - β1p)) / (vt * β[2] / (1 - β2p) + ϵ) * η @. Δ = (β[1] * mt / (1 - β[1] * β1p) + (1 - β[1]) * Δ / (1 - β1p)) / ((vt * β[2] / (1 - β2p)) + ϵ) * η
o.state[x] = (mt, vt, (β1p * β[1], β2p * β[2])) o.state[x] = (mt, vt, (β1p * β[1], β2p * β[2]))
return Δ return Δ
end end