This commit is contained in:
Mike Innes 2018-10-05 12:43:03 +01:00
parent 4abe518599
commit 9bc9771a8d
3 changed files with 6 additions and 6 deletions

View File

@ -3,7 +3,7 @@ module Optimise
export train!,
Descent, ADAM, Momentum, Nesterov, RMSProp,
ADAGrad, AdaMax, ADADelta, AMSGrad, NADAM, ADAMW,
InvDecay, ExpDecay, stop, StopException, Compose
InvDecay, ExpDecay, stop, Compose
include("optimisers.jl")
include("train.jl")

View File

@ -16,7 +16,7 @@ mutable struct Descent
eta::Float64
end
Descent(η = 0.1) = Descent(η)
Descent() = Descent(0.1)
function update!(o::Descent, x, Δ)
Δ .*= o.eta
end
@ -275,7 +275,7 @@ mutable struct InvDecay
n::Int64
end
InvDecay(γ = 0.001, n = 0) = InvDecay(γ, n)
InvDecay(γ = 0.001) = InvDecay(γ, 0)
function update!(o::InvDecay, x, Δ)
γ, n = o.gamma, o.n
@ -288,7 +288,7 @@ mutable struct ExpDecay
gamma::Float64
end
ExpDecay(γ = 0.001) = ExpDecay(γ)
ExpDecay() = ExpDecay(0.001)
function update!(o::ExpDecay, x, Δ)
γ = o.gamma
@ -300,7 +300,7 @@ mutable struct DescentWeightDecay
gamma::Real
end
DescentWeightDecay(η = 1, γ = 0) = DescentWeightDecay(η, γ)
DescentWeightDecay(η = 1) = DescentWeightDecay(η, 0)
function update!(o::DescentWeightDecay, x, Δ)
η, γ = o.eta, o.gamma
@. x = x - η * (Δ + γ * x)

View File

@ -5,7 +5,7 @@ import Base.depwarn
function update!(opt, xs)
for x in xs
x, Δ = data(x), grad(x)
update!(opt, x, Δ)
Δ = update!(opt, x, Δ)
x .-= Δ
Δ .= 0
end