This commit is contained in:
Mike Innes 2018-10-05 12:43:03 +01:00
parent 4abe518599
commit 9bc9771a8d
3 changed files with 6 additions and 6 deletions

View File

@ -3,7 +3,7 @@ module Optimise
export train!, export train!,
Descent, ADAM, Momentum, Nesterov, RMSProp, Descent, ADAM, Momentum, Nesterov, RMSProp,
ADAGrad, AdaMax, ADADelta, AMSGrad, NADAM, ADAMW, ADAGrad, AdaMax, ADADelta, AMSGrad, NADAM, ADAMW,
InvDecay, ExpDecay, stop, StopException, Compose InvDecay, ExpDecay, stop, Compose
include("optimisers.jl") include("optimisers.jl")
include("train.jl") include("train.jl")

View File

@ -16,7 +16,7 @@ mutable struct Descent
eta::Float64 eta::Float64
end end
Descent(η = 0.1) = Descent(η) Descent() = Descent(0.1)
function update!(o::Descent, x, Δ) function update!(o::Descent, x, Δ)
Δ .*= o.eta Δ .*= o.eta
end end
@ -275,7 +275,7 @@ mutable struct InvDecay
n::Int64 n::Int64
end end
InvDecay(γ = 0.001, n = 0) = InvDecay(γ, n) InvDecay(γ = 0.001) = InvDecay(γ, 0)
function update!(o::InvDecay, x, Δ) function update!(o::InvDecay, x, Δ)
γ, n = o.gamma, o.n γ, n = o.gamma, o.n
@ -288,7 +288,7 @@ mutable struct ExpDecay
gamma::Float64 gamma::Float64
end end
ExpDecay(γ = 0.001) = ExpDecay(γ) ExpDecay() = ExpDecay(0.001)
function update!(o::ExpDecay, x, Δ) function update!(o::ExpDecay, x, Δ)
γ = o.gamma γ = o.gamma
@ -300,7 +300,7 @@ mutable struct DescentWeightDecay
gamma::Real gamma::Real
end end
DescentWeightDecay(η = 1, γ = 0) = DescentWeightDecay(η, γ) DescentWeightDecay(η = 1) = DescentWeightDecay(η, 0)
function update!(o::DescentWeightDecay, x, Δ) function update!(o::DescentWeightDecay, x, Δ)
η, γ = o.eta, o.gamma η, γ = o.eta, o.gamma
@. x = x - η * (Δ + γ * x) @. x = x - η * (Δ + γ * x)

View File

@ -5,7 +5,7 @@ import Base.depwarn
function update!(opt, xs) function update!(opt, xs)
for x in xs for x in xs
x, Δ = data(x), grad(x) x, Δ = data(x), grad(x)
update!(opt, x, Δ) Δ = update!(opt, x, Δ)
x .-= Δ x .-= Δ
Δ .= 0 Δ .= 0
end end