diff --git a/src/optimise/Optimise.jl b/src/optimise/Optimise.jl
index 873a3ece..4c5c8290 100644
--- a/src/optimise/Optimise.jl
+++ b/src/optimise/Optimise.jl
@@ -3,7 +3,7 @@ module Optimise
 export train!,
   Descent, ADAM, Momentum, Nesterov, RMSProp,
   ADAGrad, AdaMax, ADADelta, AMSGrad, NADAM, ADAMW,
-  InvDecay, ExpDecay, stop, StopException, Compose
+  InvDecay, ExpDecay, stop, Compose
 
 include("optimisers.jl")
 include("train.jl")
diff --git a/src/optimise/optimisers.jl b/src/optimise/optimisers.jl
index 4005db4f..ae30445a 100644
--- a/src/optimise/optimisers.jl
+++ b/src/optimise/optimisers.jl
@@ -16,7 +16,7 @@ mutable struct Descent
   eta::Float64
 end
 
-Descent(η = 0.1) = Descent(η)
+Descent() = Descent(0.1)
 function update!(o::Descent, x, Δ)
   Δ .*= o.eta
 end
@@ -275,7 +275,7 @@ mutable struct InvDecay
   n::Int64
 end
 
-InvDecay(γ = 0.001, n = 0) = InvDecay(γ, n)
+InvDecay(γ = 0.001) = InvDecay(γ, 0)
 
 function update!(o::InvDecay, x, Δ)
   γ, n = o.gamma, o.n
@@ -288,7 +288,7 @@ mutable struct ExpDecay
   gamma::Float64
 end
 
-ExpDecay(γ = 0.001) = ExpDecay(γ)
+ExpDecay() = ExpDecay(0.001)
 
 function update!(o::ExpDecay, x, Δ)
   γ = o.gamma
@@ -300,7 +300,7 @@ mutable struct DescentWeightDecay
   gamma::Real
 end
 
-DescentWeightDecay(η = 1, γ = 0) = DescentWeightDecay(η, γ)
+DescentWeightDecay(η = 1) = DescentWeightDecay(η, 0)
 function update!(o::DescentWeightDecay, x, Δ)
   η, γ = o.eta, o.gamma
   @. x = x - η * (Δ + γ * x)
diff --git a/src/optimise/train.jl b/src/optimise/train.jl
index f65ccb2a..a8a3b4a0 100644
--- a/src/optimise/train.jl
+++ b/src/optimise/train.jl
@@ -5,7 +5,7 @@ import Base.depwarn
 function update!(opt, xs)
   for x in xs
     x, Δ = data(x), grad(x)
-    update!(opt, x, Δ)
+    Δ = update!(opt, x, Δ)
     x .-= Δ
     Δ .= 0
   end
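
A minimal usage sketch of the revised convention (not part of the diff), assuming this is Flux's Optimise module: an optimiser's update!(o, x, Δ) now returns the step that the generic loop in train.jl subtracts from the parameter. The arrays W and gW below are made-up stand-ins for a parameter's data and gradient:

    using Flux.Optimise: Descent, update!

    opt = Descent()              # new zero-argument constructor, η = 0.1
    W   = randn(3, 3)            # hypothetical parameter array
    gW  = ones(Float64, 3, 3)    # hypothetical gradient for W

    step = update!(opt, W, gW)   # Descent scales Δ in place and returns it
    W  .-= step                  # apply the step, as train.jl now does
    gW  .= 0                     # reset the gradient buffer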