using Base: depwarn
using Flux: Params

# Wrap `opt` in an `Optimiser` with inverse time decay when a non-zero `decay`
# keyword is passed to one of the deprecated constructors below.
check_decay(opt, decay) = decay == 0 ? opt : Optimiser(opt, InvDecay(decay))

# Legacy update rule: the old optimisers were zero-argument closures that
# updated the parameters captured at construction time.
updaterule(opt, ps) = () -> _update_params!(opt, ps)

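# Illustrative sketch (comment only, not executed; `model` is a placeholder name):
# each deprecated constructor below returns such a zero-argument closure, so
# legacy code along the lines of
#
#     opt = SGD(params(model), 0.1)
#     opt()   # one in-place update of the captured parameters
#
# keeps working against the new optimiser types.
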
function SGD(params::Union{AbstractArray, Params}, η = 0.1; decay = 0.)
  depwarn("SGD(params) is deprecated; use Descent(η::Float64) instead", :SGD)

  ps = params
  opt = Descent(η)
  opt = check_decay(opt, decay)
  updaterule(opt, ps)
end

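# Migration sketch (`m`, `loss`, and `data` are placeholder names for illustration):
#
#     # deprecated API                        # current API
#     opt = SGD(params(m), 0.1)               opt = Descent(0.1)
#     Flux.train!(loss, data, opt)            Flux.train!(loss, params(m), data, opt)
#
# The same substitution applies to the other deprecated constructors in this file.
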
function Momentum(params::Union{AbstractArray, Params}, η = 0.01; ρ = 0.9, decay = 0.)
  depwarn("Momentum(params) is deprecated; use Momentum(η::Float64) instead", :Momentum)

  ps = params
  opt = Momentum(η, ρ)
  opt = check_decay(opt, decay)
  updaterule(opt, ps)
end

function Nesterov(params::Union{AbstractArray, Params}, η = 0.001; ρ = 0.9, decay = 0.)
  depwarn("Nesterov(params) is deprecated; use Nesterov(η::Float64) instead", :Nesterov)

  ps = params
  opt = Nesterov(η, ρ)
  opt = check_decay(opt, decay)
  updaterule(opt, ps)
end

function RMSProp(params::Union{AbstractArray, Params}, η = 0.001; ρ = 0.9, decay = 0.)
  depwarn("RMSProp(params) is deprecated; use RMSProp(η::Float64) instead", :RMSProp)

  ps = params
  opt = RMSProp(η, ρ)
  opt = check_decay(opt, decay)
  updaterule(opt, ps)
end

function ADAM(params::Union{AbstractArray, Params}, η = 0.001; β1 = 0.9, β2 = 0.999, decay = 0.)
  depwarn("ADAM(params) is deprecated; use ADAM(η::Float64) instead", :ADAM)

  ps = params
  β = (β1, β2)
  opt = ADAM(η, β)
  opt = check_decay(opt, decay)
  updaterule(opt, ps)
end

function ADAGrad(params::Union{AbstractArray, Params}, η::Float64 = 0.1; decay = 0.)
  depwarn("ADAGrad(params) is deprecated; use ADAGrad(η::Float64) instead", :ADAGrad)

  ps = params
  opt = ADAGrad(η)
  opt = check_decay(opt, decay)
  updaterule(opt, ps)
end

function ADADelta(params::Union{AbstractArray, Params}, ρ::Float64 = 0.9; decay = 0.)
  depwarn("ADADelta(params) is deprecated; use ADADelta(ρ::Float64) instead", :ADADelta)

  ps = params
  opt = ADADelta(ρ)
  opt = check_decay(opt, decay)
  updaterule(opt, ps)
end

function AdaMax(params::Union{AbstractArray, Params}, η = 0.001; β1 = 0.9, β2 = 0.999, decay = 0.)
  depwarn("AdaMax(params) is deprecated; use AdaMax(η::Float64) instead", :AdaMax)

  ps = params
  β = (β1, β2)
  opt = AdaMax(η, β)
  opt = check_decay(opt, decay)
  updaterule(opt, ps)
end

function AMSGrad(params::Union{AbstractArray, Params}, η = 0.001; β1 = 0.9, β2 = 0.999, decay = 0.)
  depwarn("AMSGrad(params) is deprecated; use AMSGrad(η::Float64) instead", :AMSGrad)

  ps = params
  β = (β1, β2)
  opt = AMSGrad(η, β)
  opt = check_decay(opt, decay)
  updaterule(opt, ps)
end

function NADAM(params::Union{AbstractArray, Params}, η = 0.001; β1 = 0.9, β2 = 0.999, decay = 0.)
  depwarn("NADAM(params) is deprecated; use NADAM(η::Float64) instead", :NADAM)

  ps = params
  β = (β1, β2)
  opt = NADAM(η, β)
  opt = check_decay(opt, decay)
  updaterule(opt, ps)
end

function ADAMW(params::Union{AbstractArray, Params}, η = 0.001; β1 = 0.9, β2 = 0.999, decay = 0.)
  depwarn("ADAMW(params) is deprecated; use ADAMW(η::Float64) instead", :ADAMW)

  ps = params
  β = (β1, β2)
  opt = ADAMW(η, β)
  opt = check_decay(opt, decay)
  # ADAMW also treats `decay` as a decoupled weight-decay penalty, so add a
  # `WeightDecay` term on top of the inverse time decay above.
  decay != 0 && (opt = Optimiser(opt, WeightDecay(decay)))
  updaterule(opt, ps)
end

# Old training loop

# Adapter so the new `train!` loop can drive a legacy closure-style optimiser
# returned by the constructors above.
struct OldOptimiser
  func
end

_update_params!(opt::OldOptimiser, ps) = opt.func()

# Train function
function train!(loss, data, opt; cb = () -> ())
  depwarn("train!(loss, data, opt) is deprecated; use train!(loss, params, data, opt) instead", :train!)
  train!(loss, (), data, OldOptimiser(opt); cb = cb)
end
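
# Illustrative sketch (comment only; `loss`, `data`, and `m` are placeholder names):
# a legacy call such as
#
#     Flux.train!(loss, data, SGD(params(m), 0.1))
#
# emits a deprecation warning and is forwarded as
# `train!(loss, (), data, OldOptimiser(opt); cb = cb)`, passing an empty parameter
# collection because the legacy closure already captured its parameters; the wrapped
# closure then performs the update whenever `_update_params!` is called.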