commit
9345607c38
@ -23,7 +23,7 @@ include("optimise/Optimise.jl")
|
|||||||
using .Optimise
|
using .Optimise
|
||||||
using .Optimise: @epochs
|
using .Optimise: @epochs
|
||||||
export SGD, ADAM, AdaMax, Momentum, Nesterov,
|
export SGD, ADAM, AdaMax, Momentum, Nesterov,
|
||||||
RMSProp, ADAGrad, ADADelta, AMSGrad
|
RMSProp, ADAGrad, ADADelta, AMSGrad, NADAM
|
||||||
|
|
||||||
include("utils.jl")
|
include("utils.jl")
|
||||||
include("onehot.jl")
|
include("onehot.jl")
|
||||||
|
@ -1,7 +1,8 @@
|
|||||||
module Optimise
|
module Optimise
|
||||||
|
|
||||||
export train!,
|
export train!,
|
||||||
SGD, ADAM, AdaMax, Momentum, Nesterov, RMSProp, ADAGrad, ADADelta, AMSGrad
|
SGD, ADAM, AdaMax, Momentum, Nesterov,
|
||||||
|
RMSProp, ADAGrad, ADADelta, AMSGrad, NADAM
|
||||||
|
|
||||||
struct Param{T}
|
struct Param{T}
|
||||||
x::T
|
x::T
|
||||||
|
@ -91,3 +91,12 @@ tuning.
|
|||||||
"""
|
"""
|
||||||
AMSGrad(ps, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0) =
|
AMSGrad(ps, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0) =
|
||||||
optimiser(ps, p -> amsgrad(p; η = η, β1 = β1, β2 = β2, ϵ = ϵ), p -> invdecay(p, decay), p -> descent(p, 1))
|
optimiser(ps, p -> amsgrad(p; η = η, β1 = β1, β2 = β2, ϵ = ϵ), p -> invdecay(p, decay), p -> descent(p, 1))
|
||||||
|
|
||||||
|
"""
|
||||||
|
NADAM(params, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0)
|
||||||
|
|
||||||
|
[NADAM](https://openreview.net/pdf?id=OM0jvwB8jIp57ZJjtNEZ) optimiser. Parameters other
|
||||||
|
than learning rate don't need tuning.
|
||||||
|
"""
|
||||||
|
NADAM(ps, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0) =
|
||||||
|
optimiser(ps, p->nadam(p; η=η, β1=β1, β2=β2, ϵ=ϵ), p->invdecay(p,decay), p->descent(p,1))
|
||||||
|
@ -27,7 +27,7 @@ function rmsprop(p::Param; η::Real = 0.001, ρ::Real = 0.9, ϵ::Real = 1e-8)
|
|||||||
acc = zeros(p.x)
|
acc = zeros(p.x)
|
||||||
function ()
|
function ()
|
||||||
@. acc = ρ * acc + (1 - ρ) * p.Δ^2
|
@. acc = ρ * acc + (1 - ρ) * p.Δ^2
|
||||||
@. p.Δ *= η / (√acc + ϵ)
|
@. p.Δ *= η / √(acc + ϵ)
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
@ -35,7 +35,7 @@ function adagrad(p::Param; η::Real = 0.01, ϵ::Real = 1e-8)
|
|||||||
acc = zeros(p.x) .+ ϵ
|
acc = zeros(p.x) .+ ϵ
|
||||||
function ()
|
function ()
|
||||||
@. acc += p.Δ^2
|
@. acc += p.Δ^2
|
||||||
@. p.Δ *= η / √acc
|
@. p.Δ *= η / √(acc + ϵ)
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
@ -56,7 +56,7 @@ function adam(p::Param; η::Real = 0.001, β1::Real = 0.9, β2::Real = 0.999, ϵ
|
|||||||
function ()
|
function ()
|
||||||
@. mt = β1 * mt + (1 - β1) * p.Δ
|
@. mt = β1 * mt + (1 - β1) * p.Δ
|
||||||
@. vt = β2 * vt + (1 - β2) * p.Δ^2
|
@. vt = β2 * vt + (1 - β2) * p.Δ^2
|
||||||
@. p.Δ = mt / (1 - β1p) / (√(vt / (1 - β2p)) + ϵ) * η
|
@. p.Δ = mt / (1 - β1p) / √(vt / (1 - β2p) + ϵ) * η
|
||||||
β1p *= β1
|
β1p *= β1
|
||||||
β2p *= β2
|
β2p *= β2
|
||||||
end
|
end
|
||||||
@ -86,6 +86,19 @@ function amsgrad(p::Param; η::Real = 0.001, β1::Real = 0.9, β2::Real = 0.999,
|
|||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
|
# Build the NADAM update step for the parameter `p`.
#
# Returns a zero-argument closure. Each call refreshes the two running averages from the
# current gradient `p.Δ` and then overwrites `p.Δ` in place with the NADAM step.
#
# Keywords: `η` scales the final step, `β1` / `β2` are the averaging coefficients of the
# first / second moment, and `ϵ` is added under the square root of the denominator.
function nadam(p::Param; η::Real = 0.001, β1::Real = 0.9, β2::Real = 0.999, ϵ::Real = 1e-8)
  moment1 = zeros(p.x)   # running average of the gradient
  moment2 = zeros(p.x)   # running average of the squared gradient
  # Running powers of β1 and β2; they start at the first power and gain one factor per call.
  β1pow, β2pow = β1, β2
  return function ()
    moment1 .= β1 .* moment1 .+ (1 - β1) .* p.Δ
    moment2 .= β2 .* moment2 .+ (1 - β2) .* p.Δ .^ 2
    # Numerator: bias-corrected first moment (one step ahead) plus bias-corrected gradient.
    # NOTE(review): the second-moment term carries a factor of β2 on top of the usual
    # `1 / (1 - β2^t)` correction; kept exactly as written — confirm against the reference.
    p.Δ .= (β1 .* moment1 ./ (1 - β1 * β1pow) .+ (1 - β1) .* p.Δ ./ (1 - β1pow)) ./
           sqrt.(moment2 .* β2 ./ (1 - β2pow) .+ ϵ) .* η
    β1pow *= β1
    β2pow *= β2
  end
end
|
||||||
|
|
||||||
# Build a step that clamps every entry of the gradient `p.Δ`, in place, to the
# interval `[-thresh, thresh]`.
function clip(p::Param, thresh::Real)
  lower, upper = -thresh, thresh
  return () -> clamp!(p.Δ, lower, upper)
end
|
||||||
|
|
||||||
function expdecay(p::Param, γ::Real)
|
function expdecay(p::Param, γ::Real)
|
||||||
|
@ -3,7 +3,7 @@ using Flux.Tracker
|
|||||||
|
|
||||||
@testset "Optimise" begin
|
@testset "Optimise" begin
|
||||||
w = randn(10, 10)
|
w = randn(10, 10)
|
||||||
@testset for Opt in [SGD, Nesterov, Momentum, ADAM, AdaMax, RMSProp, ps -> ADAGrad(ps, 0.1), ADADelta, AMSGrad]
|
@testset for Opt in [SGD, Nesterov, Momentum, ADAM, AdaMax, RMSProp, ps -> ADAGrad(ps, 0.1), ADADelta, AMSGrad, NADAM]
|
||||||
w′ = param(randn(10, 10))
|
w′ = param(randn(10, 10))
|
||||||
loss(x) = Flux.mse(w*x, w′*x)
|
loss(x) = Flux.mse(w*x, w′*x)
|
||||||
opt = Opt([w′])
|
opt = Opt([w′])
|
||||||
|
Loading…
Reference in New Issue
Block a user