fix deprecations
This commit is contained in:
parent
b05cd41c99
commit
58a6c3f225
@ -1,11 +1,12 @@
|
|||||||
using Base: depwarn
|
using Base: depwarn
|
||||||
|
using Flux: Params
|
||||||
|
|
||||||
check_decay(opt, decay) = decay == 0 ? opt : Optimiser(opt, InvDecay(decay))
|
check_decay(opt, decay) = decay == 0 ? opt : Optimiser(opt, InvDecay(decay))
|
||||||
|
|
||||||
# legacy update rule
|
# legacy update rule
|
||||||
updaterule(opt, ps) = () -> update!(p, ps)
|
updaterule(opt, ps) = () -> update!(opt, ps)
|
||||||
|
|
||||||
function SGD(params::AbstractArray, η = 0.1; decay = 0.)
|
function SGD(params::Params, η = 0.1; decay = 0.)
|
||||||
depwarn("SGD(params) is deprecated; use Descent(η::Float64) instead", :SGD)
|
depwarn("SGD(params) is deprecated; use Descent(η::Float64) instead", :SGD)
|
||||||
|
|
||||||
ps = params
|
ps = params
|
||||||
@ -14,7 +15,7 @@ function SGD(params::AbstractArray, η = 0.1; decay = 0.)
|
|||||||
updaterule(opt, ps)
|
updaterule(opt, ps)
|
||||||
end
|
end
|
||||||
|
|
||||||
function Momentum(params::AbstractArray, η = 0.01; ρ = 0.9, decay = 0.)
|
function Momentum(params::Params, η = 0.01; ρ = 0.9, decay = 0.)
|
||||||
depwarn("Momentum(params) is deprecated; use Momentum(η::Float64) instead", :Momentum)
|
depwarn("Momentum(params) is deprecated; use Momentum(η::Float64) instead", :Momentum)
|
||||||
|
|
||||||
ps = params
|
ps = params
|
||||||
@ -23,7 +24,7 @@ function Momentum(params::AbstractArray, η = 0.01; ρ = 0.9, decay = 0.)
|
|||||||
updaterule(opt, ps)
|
updaterule(opt, ps)
|
||||||
end
|
end
|
||||||
|
|
||||||
function Nesterov(params::AbstractArray, η = 0.001; ρ = 0.9, decay = 0.)
|
function Nesterov(params::Params, η = 0.001; ρ = 0.9, decay = 0.)
|
||||||
depwarn("Nesterov(params) is deprecated; use Nesterov(η::Float64) instead", :Nesterov)
|
depwarn("Nesterov(params) is deprecated; use Nesterov(η::Float64) instead", :Nesterov)
|
||||||
|
|
||||||
ps = params
|
ps = params
|
||||||
@ -32,7 +33,7 @@ function Nesterov(params::AbstractArray, η = 0.001; ρ = 0.9, decay = 0.)
|
|||||||
updaterule(opt, ps)
|
updaterule(opt, ps)
|
||||||
end
|
end
|
||||||
|
|
||||||
function RMSProp(params::AbstractArray, η = 0.001; ρ = 0.9, decay = 0.)
|
function RMSProp(params::Params, η = 0.001; ρ = 0.9, decay = 0.)
|
||||||
depwarn("RMSProp(params) is deprecated; use RMSProp(η::Float64) instead", :RMSProp)
|
depwarn("RMSProp(params) is deprecated; use RMSProp(η::Float64) instead", :RMSProp)
|
||||||
|
|
||||||
ps = params
|
ps = params
|
||||||
@ -41,7 +42,7 @@ function RMSProp(params::AbstractArray, η = 0.001; ρ = 0.9, decay = 0.)
|
|||||||
updaterule(opt, ps)
|
updaterule(opt, ps)
|
||||||
end
|
end
|
||||||
|
|
||||||
function ADAM(params::AbstractArray, η = 0.001; β1 = 0.9, β2 = 0.999, decay = 0.)
|
function ADAM(params::Params, η = 0.001; β1 = 0.9, β2 = 0.999, decay = 0.)
|
||||||
depwarn("ADAM(params) is deprecated; use ADAM(η::Float64) instead", :ADAM)
|
depwarn("ADAM(params) is deprecated; use ADAM(η::Float64) instead", :ADAM)
|
||||||
|
|
||||||
ps = params
|
ps = params
|
||||||
@ -51,7 +52,7 @@ function ADAM(params::AbstractArray, η = 0.001; β1 = 0.9, β2 = 0.999, decay =
|
|||||||
updaterule(opt, ps)
|
updaterule(opt, ps)
|
||||||
end
|
end
|
||||||
|
|
||||||
function ADAGrad(params::AbstractArray, η::Float64 = 0.1; decay = 0.)
|
function ADAGrad(params::Params, η::Float64 = 0.1; decay = 0.)
|
||||||
depwarn("ADAGrad(params) is deprecated; use ADAGrad(η::Float64) instead", :ADAGrad)
|
depwarn("ADAGrad(params) is deprecated; use ADAGrad(η::Float64) instead", :ADAGrad)
|
||||||
|
|
||||||
ps = params
|
ps = params
|
||||||
@ -60,7 +61,7 @@ function ADAGrad(params::AbstractArray, η::Float64 = 0.1; decay = 0.)
|
|||||||
updaterule(opt, ps)
|
updaterule(opt, ps)
|
||||||
end
|
end
|
||||||
|
|
||||||
function ADADelta(params::AbstractArray, ρ::Float64 = 0.9; decay = 0.)
|
function ADADelta(params::Params, ρ::Float64 = 0.9; decay = 0.)
|
||||||
depwarn("ADADelta(params) is deprecated; use ADADelta(η::Float64) instead", :ADADelta)
|
depwarn("ADADelta(params) is deprecated; use ADADelta(η::Float64) instead", :ADADelta)
|
||||||
|
|
||||||
ps = params
|
ps = params
|
||||||
@ -69,7 +70,7 @@ function ADADelta(params::AbstractArray, ρ::Float64 = 0.9; decay = 0.)
|
|||||||
updaterule(opt, ps)
|
updaterule(opt, ps)
|
||||||
end
|
end
|
||||||
|
|
||||||
function AdaMax(params::AbstractArray, η = 0.001; β1 = 0.9, β2 = 0.999, decay = 0.)
|
function AdaMax(params::Params, η = 0.001; β1 = 0.9, β2 = 0.999, decay = 0.)
|
||||||
depwarn("AdaMax(params) is deprecated; use AdaMax(η::Float64) instead", :AdaMax)
|
depwarn("AdaMax(params) is deprecated; use AdaMax(η::Float64) instead", :AdaMax)
|
||||||
|
|
||||||
ps = params
|
ps = params
|
||||||
@ -79,7 +80,7 @@ function AdaMax(params::AbstractArray, η = 0.001; β1 = 0.9, β2 = 0.999, decay
|
|||||||
updaterule(opt, ps)
|
updaterule(opt, ps)
|
||||||
end
|
end
|
||||||
|
|
||||||
function AMSGrad(params::AbstractArray, η = 0.001; β1 = 0.9, β2 = 0.999, decay = 0.)
|
function AMSGrad(params::Params, η = 0.001; β1 = 0.9, β2 = 0.999, decay = 0.)
|
||||||
depwarn("AMSGrad(params) is deprecated; use AMSGrad(η::Float64) instead", :AMSGrad)
|
depwarn("AMSGrad(params) is deprecated; use AMSGrad(η::Float64) instead", :AMSGrad)
|
||||||
|
|
||||||
ps = params
|
ps = params
|
||||||
@ -89,7 +90,7 @@ function AMSGrad(params::AbstractArray, η = 0.001; β1 = 0.9, β2 = 0.999, deca
|
|||||||
updaterule(opt, ps)
|
updaterule(opt, ps)
|
||||||
end
|
end
|
||||||
|
|
||||||
function NADAM(params::AbstractArray, η = 0.001; β1 = 0.9, β2 = 0.999, decay = 0.)
|
function NADAM(params::Params, η = 0.001; β1 = 0.9, β2 = 0.999, decay = 0.)
|
||||||
depwarn("NADAM(params) is deprecated; use NADAM(η::Float64) instead", :NADAM)
|
depwarn("NADAM(params) is deprecated; use NADAM(η::Float64) instead", :NADAM)
|
||||||
|
|
||||||
ps = params
|
ps = params
|
||||||
@ -99,14 +100,14 @@ function NADAM(params::AbstractArray, η = 0.001; β1 = 0.9, β2 = 0.999, decay
|
|||||||
updaterule(opt, ps)
|
updaterule(opt, ps)
|
||||||
end
|
end
|
||||||
|
|
||||||
function ADAMW(params::AbstractArray, η = 0.001; β1 = 0.9, β2 = 0.999, decay = 0.)
|
function ADAMW(params::Params, η = 0.001; β1 = 0.9, β2 = 0.999, decay = 0.)
|
||||||
depwarn("ADAMW(params) is deprecated; use ADAMW(η::Float64) instead", :ADAMW)
|
depwarn("ADAMW(params) is deprecated; use ADAMW(η::Float64) instead", :ADAMW)
|
||||||
|
|
||||||
ps = params
|
ps = params
|
||||||
β = (β1, β2)
|
β = (β1, β2)
|
||||||
opt = ADAMW(η, β)
|
opt = ADAMW(η, β)
|
||||||
opt = check_decay(opt, decay)
|
opt = check_decay(opt, decay)
|
||||||
decay != 0 && (opt = Optimiser(opt, WeightDecay(decay)))
|
decay != 0 && (opt = Optimiser(opt, WeightDecay(η * decay)))
|
||||||
updaterule(opt, ps)
|
updaterule(opt, ps)
|
||||||
end
|
end
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user