diff --git a/src/optimise/deprecations.jl b/src/optimise/deprecations.jl index a90a6a79..34853bf6 100644 --- a/src/optimise/deprecations.jl +++ b/src/optimise/deprecations.jl @@ -6,7 +6,7 @@ check_decay(opt, decay) = decay == 0 ? opt : Optimiser(opt, InvDecay(decay)) # legacy update rule updaterule(opt, ps) = () -> update!(opt, ps) -function SGD(params::Params, η = 0.1; decay = 0.) +function SGD(params::Union{AbstractArray, Params}, η = 0.1; decay = 0.) depwarn("SGD(params) is deprecated; use Descent(η::Float64) instead", :SGD) ps = params @@ -15,7 +15,7 @@ function SGD(params::Params, η = 0.1; decay = 0.) updaterule(opt, ps) end -function Momentum(params::Params, η = 0.01; ρ = 0.9, decay = 0.) +function Momentum(params::Union{AbstractArray, Params}, η = 0.01; ρ = 0.9, decay = 0.) depwarn("Momentum(params) is deprecated; use Momentum(η::Float64) instead", :Momentum) ps = params @@ -24,7 +24,7 @@ function Momentum(params::Params, η = 0.01; ρ = 0.9, decay = 0.) updaterule(opt, ps) end -function Nesterov(params::Params, η = 0.001; ρ = 0.9, decay = 0.) +function Nesterov(params::Union{AbstractArray, Params}, η = 0.001; ρ = 0.9, decay = 0.) depwarn("Nesterov(params) is deprecated; use Nesterov(η::Float64) instead", :Nesterov) ps = params @@ -33,7 +33,7 @@ function Nesterov(params::Params, η = 0.001; ρ = 0.9, decay = 0.) updaterule(opt, ps) end -function RMSProp(params::Params, η = 0.001; ρ = 0.9, decay = 0.) +function RMSProp(params::Union{AbstractArray, Params}, η = 0.001; ρ = 0.9, decay = 0.) depwarn("RMSProp(params) is deprecated; use RMSProp(η::Float64) instead", :RMSProp) ps = params @@ -42,7 +42,7 @@ function RMSProp(params::Params, η = 0.001; ρ = 0.9, decay = 0.) updaterule(opt, ps) end -function ADAM(params::Params, η = 0.001; β1 = 0.9, β2 = 0.999, decay = 0.) +function ADAM(params::Union{AbstractArray, Params}, η = 0.001; β1 = 0.9, β2 = 0.999, decay = 0.) depwarn("ADAM(params) is deprecated; use ADAM(η::Float64) instead", :ADAM) ps = params @@ -52,7 +52,7 @@ function ADAM(params::Params, η = 0.001; β1 = 0.9, β2 = 0.999, decay = 0.) updaterule(opt, ps) end -function ADAGrad(params::Params, η::Float64 = 0.1; decay = 0.) +function ADAGrad(params::Union{AbstractArray, Params}, η::Float64 = 0.1; decay = 0.) depwarn("ADAGrad(params) is deprecated; use ADAGrad(η::Float64) instead", :ADAGrad) ps = params @@ -61,7 +61,7 @@ function ADAGrad(params::Params, η::Float64 = 0.1; decay = 0.) updaterule(opt, ps) end -function ADADelta(params::Params, ρ::Float64 = 0.9; decay = 0.) +function ADADelta(params::Union{AbstractArray, Params}, ρ::Float64 = 0.9; decay = 0.) depwarn("ADADelta(params) is deprecated; use ADADelta(η::Float64) instead", :ADADelta) ps = params @@ -70,7 +70,7 @@ function ADADelta(params::Params, ρ::Float64 = 0.9; decay = 0.) updaterule(opt, ps) end -function AdaMax(params::Params, η = 0.001; β1 = 0.9, β2 = 0.999, decay = 0.) +function AdaMax(params::Union{AbstractArray, Params}, η = 0.001; β1 = 0.9, β2 = 0.999, decay = 0.) depwarn("AdaMax(params) is deprecated; use AdaMax(η::Float64) instead", :AdaMax) ps = params @@ -80,7 +80,7 @@ function AdaMax(params::Params, η = 0.001; β1 = 0.9, β2 = 0.999, decay = 0.) updaterule(opt, ps) end -function AMSGrad(params::Params, η = 0.001; β1 = 0.9, β2 = 0.999, decay = 0.) +function AMSGrad(params::Union{AbstractArray, Params}, η = 0.001; β1 = 0.9, β2 = 0.999, decay = 0.) depwarn("AMSGrad(params) is deprecated; use AMSGrad(η::Float64) instead", :AMSGrad) ps = params @@ -90,7 +90,7 @@ function AMSGrad(params::Params, η = 0.001; β1 = 0.9, β2 = 0.999, decay = 0.) updaterule(opt, ps) end -function NADAM(params::Params, η = 0.001; β1 = 0.9, β2 = 0.999, decay = 0.) +function NADAM(params::Union{AbstractArray, Params}, η = 0.001; β1 = 0.9, β2 = 0.999, decay = 0.) depwarn("NADAM(params) is deprecated; use NADAM(η::Float64) instead", :NADAM) ps = params @@ -100,7 +100,7 @@ function NADAM(params::Params, η = 0.001; β1 = 0.9, β2 = 0.999, decay = 0.) updaterule(opt, ps) end -function ADAMW(params::Params, η = 0.001; β1 = 0.9, β2 = 0.999, decay = 0.) +function ADAMW(params::Union{AbstractArray, Params}, η = 0.001; β1 = 0.9, β2 = 0.999, decay = 0.) depwarn("ADAMW(params) is deprecated; use ADAMW(η::Float64) instead", :ADAMW) ps = params