tweaks
This commit is contained in:
parent
4abe518599
commit
9bc9771a8d
@ -3,7 +3,7 @@ module Optimise
|
||||
export train!,
|
||||
Descent, ADAM, Momentum, Nesterov, RMSProp,
|
||||
ADAGrad, AdaMax, ADADelta, AMSGrad, NADAM, ADAMW,
|
||||
InvDecay, ExpDecay, stop, StopException, Compose
|
||||
InvDecay, ExpDecay, stop, Compose
|
||||
|
||||
include("optimisers.jl")
|
||||
include("train.jl")
|
||||
|
@ -16,7 +16,7 @@ mutable struct Descent
|
||||
eta::Float64
|
||||
end
|
||||
|
||||
Descent(η = 0.1) = Descent(η)
|
||||
Descent() = Descent(0.1)
|
||||
function update!(o::Descent, x, Δ)
|
||||
Δ .*= o.eta
|
||||
end
|
||||
@ -275,7 +275,7 @@ mutable struct InvDecay
|
||||
n::Int64
|
||||
end
|
||||
|
||||
InvDecay(γ = 0.001, n = 0) = InvDecay(γ, n)
|
||||
InvDecay(γ = 0.001) = InvDecay(γ, 0)
|
||||
|
||||
function update!(o::InvDecay, x, Δ)
|
||||
γ, n = o.gamma, o.n
|
||||
@ -288,7 +288,7 @@ mutable struct ExpDecay
|
||||
gamma::Float64
|
||||
end
|
||||
|
||||
ExpDecay(γ = 0.001) = ExpDecay(γ)
|
||||
ExpDecay() = ExpDecay(0.001)
|
||||
|
||||
function update!(o::ExpDecay, x, Δ)
|
||||
γ = o.gamma
|
||||
@ -300,7 +300,7 @@ mutable struct DescentWeightDecay
|
||||
gamma::Real
|
||||
end
|
||||
|
||||
DescentWeightDecay(η = 1, γ = 0) = DescentWeightDecay(η, γ)
|
||||
DescentWeightDecay(η = 1) = DescentWeightDecay(η, 0)
|
||||
function update!(o::DescentWeightDecay, x, Δ)
|
||||
η, γ = o.eta, o.gamma
|
||||
@. x = x - η * (Δ + γ * x)
|
||||
|
@ -5,7 +5,7 @@ import Base.depwarn
|
||||
function update!(opt, xs)
|
||||
for x in xs
|
||||
x, Δ = data(x), grad(x)
|
||||
update!(opt, x, Δ)
|
||||
Δ = update!(opt, x, Δ)
|
||||
x .-= Δ
|
||||
Δ .= 0
|
||||
end
|
||||
|
Loading…
Reference in New Issue
Block a user