tweaks
This commit is contained in:
parent
4abe518599
commit
9bc9771a8d
@ -3,7 +3,7 @@ module Optimise
|
|||||||
export train!,
|
export train!,
|
||||||
Descent, ADAM, Momentum, Nesterov, RMSProp,
|
Descent, ADAM, Momentum, Nesterov, RMSProp,
|
||||||
ADAGrad, AdaMax, ADADelta, AMSGrad, NADAM, ADAMW,
|
ADAGrad, AdaMax, ADADelta, AMSGrad, NADAM, ADAMW,
|
||||||
InvDecay, ExpDecay, stop, StopException, Compose
|
InvDecay, ExpDecay, stop, Compose
|
||||||
|
|
||||||
include("optimisers.jl")
|
include("optimisers.jl")
|
||||||
include("train.jl")
|
include("train.jl")
|
||||||
|
@ -16,7 +16,7 @@ mutable struct Descent
|
|||||||
eta::Float64
|
eta::Float64
|
||||||
end
|
end
|
||||||
|
|
||||||
Descent(η = 0.1) = Descent(η)
|
Descent() = Descent(0.1)
|
||||||
function update!(o::Descent, x, Δ)
|
function update!(o::Descent, x, Δ)
|
||||||
Δ .*= o.eta
|
Δ .*= o.eta
|
||||||
end
|
end
|
||||||
@ -275,7 +275,7 @@ mutable struct InvDecay
|
|||||||
n::Int64
|
n::Int64
|
||||||
end
|
end
|
||||||
|
|
||||||
InvDecay(γ = 0.001, n = 0) = InvDecay(γ, n)
|
InvDecay(γ = 0.001) = InvDecay(γ, 0)
|
||||||
|
|
||||||
function update!(o::InvDecay, x, Δ)
|
function update!(o::InvDecay, x, Δ)
|
||||||
γ, n = o.gamma, o.n
|
γ, n = o.gamma, o.n
|
||||||
@ -288,7 +288,7 @@ mutable struct ExpDecay
|
|||||||
gamma::Float64
|
gamma::Float64
|
||||||
end
|
end
|
||||||
|
|
||||||
ExpDecay(γ = 0.001) = ExpDecay(γ)
|
ExpDecay() = ExpDecay(0.001)
|
||||||
|
|
||||||
function update!(o::ExpDecay, x, Δ)
|
function update!(o::ExpDecay, x, Δ)
|
||||||
γ = o.gamma
|
γ = o.gamma
|
||||||
@ -300,7 +300,7 @@ mutable struct DescentWeightDecay
|
|||||||
gamma::Real
|
gamma::Real
|
||||||
end
|
end
|
||||||
|
|
||||||
DescentWeightDecay(η = 1, γ = 0) = DescentWeightDecay(η, γ)
|
DescentWeightDecay(η = 1) = DescentWeightDecay(η, 0)
|
||||||
function update!(o::DescentWeightDecay, x, Δ)
|
function update!(o::DescentWeightDecay, x, Δ)
|
||||||
η, γ = o.eta, o.gamma
|
η, γ = o.eta, o.gamma
|
||||||
@. x = x - η * (Δ + γ * x)
|
@. x = x - η * (Δ + γ * x)
|
||||||
|
@ -5,7 +5,7 @@ import Base.depwarn
|
|||||||
function update!(opt, xs)
|
function update!(opt, xs)
|
||||||
for x in xs
|
for x in xs
|
||||||
x, Δ = data(x), grad(x)
|
x, Δ = data(x), grad(x)
|
||||||
update!(opt, x, Δ)
|
Δ = update!(opt, x, Δ)
|
||||||
x .-= Δ
|
x .-= Δ
|
||||||
Δ .= 0
|
Δ .= 0
|
||||||
end
|
end
|
||||||
|
Loading…
Reference in New Issue
Block a user