allow array parameters to old optimisers

This commit is contained in:
Dhairya Gandhi 2018-11-01 22:17:54 +05:30
parent ca4e01ac26
commit 5ec70fe29d

View File

@ -6,7 +6,7 @@ check_decay(opt, decay) = decay == 0 ? opt : Optimiser(opt, InvDecay(decay))
# legacy update rule
updaterule(opt, ps) = () -> update!(opt, ps)
function SGD(params::Params, η = 0.1; decay = 0.)
function SGD(params::Union{AbstractArray, Params}, η = 0.1; decay = 0.)
depwarn("SGD(params) is deprecated; use Descent(η::Float64) instead", :SGD)
ps = params
@ -15,7 +15,7 @@ function SGD(params::Params, η = 0.1; decay = 0.)
updaterule(opt, ps)
end
function Momentum(params::Params, η = 0.01; ρ = 0.9, decay = 0.)
function Momentum(params::Union{AbstractArray, Params}, η = 0.01; ρ = 0.9, decay = 0.)
depwarn("Momentum(params) is deprecated; use Momentum(η::Float64) instead", :Momentum)
ps = params
@ -24,7 +24,7 @@ function Momentum(params::Params, η = 0.01; ρ = 0.9, decay = 0.)
updaterule(opt, ps)
end
function Nesterov(params::Params, η = 0.001; ρ = 0.9, decay = 0.)
function Nesterov(params::Union{AbstractArray, Params}, η = 0.001; ρ = 0.9, decay = 0.)
depwarn("Nesterov(params) is deprecated; use Nesterov(η::Float64) instead", :Nesterov)
ps = params
@ -33,7 +33,7 @@ function Nesterov(params::Params, η = 0.001; ρ = 0.9, decay = 0.)
updaterule(opt, ps)
end
function RMSProp(params::Params, η = 0.001; ρ = 0.9, decay = 0.)
function RMSProp(params::Union{AbstractArray, Params}, η = 0.001; ρ = 0.9, decay = 0.)
depwarn("RMSProp(params) is deprecated; use RMSProp(η::Float64) instead", :RMSProp)
ps = params
@ -42,7 +42,7 @@ function RMSProp(params::Params, η = 0.001; ρ = 0.9, decay = 0.)
updaterule(opt, ps)
end
function ADAM(params::Params, η = 0.001; β1 = 0.9, β2 = 0.999, decay = 0.)
function ADAM(params::Union{AbstractArray, Params}, η = 0.001; β1 = 0.9, β2 = 0.999, decay = 0.)
depwarn("ADAM(params) is deprecated; use ADAM(η::Float64) instead", :ADAM)
ps = params
@ -52,7 +52,7 @@ function ADAM(params::Params, η = 0.001; β1 = 0.9, β2 = 0.999, decay = 0.)
updaterule(opt, ps)
end
function ADAGrad(params::Params, η::Float64 = 0.1; decay = 0.)
function ADAGrad(params::Union{AbstractArray, Params}, η::Float64 = 0.1; decay = 0.)
depwarn("ADAGrad(params) is deprecated; use ADAGrad(η::Float64) instead", :ADAGrad)
ps = params
@ -61,7 +61,7 @@ function ADAGrad(params::Params, η::Float64 = 0.1; decay = 0.)
updaterule(opt, ps)
end
function ADADelta(params::Params, ρ::Float64 = 0.9; decay = 0.)
function ADADelta(params::Union{AbstractArray, Params}, ρ::Float64 = 0.9; decay = 0.)
depwarn("ADADelta(params) is deprecated; use ADADelta(η::Float64) instead", :ADADelta)
ps = params
@ -70,7 +70,7 @@ function ADADelta(params::Params, ρ::Float64 = 0.9; decay = 0.)
updaterule(opt, ps)
end
function AdaMax(params::Params, η = 0.001; β1 = 0.9, β2 = 0.999, decay = 0.)
function AdaMax(params::Union{AbstractArray, Params}, η = 0.001; β1 = 0.9, β2 = 0.999, decay = 0.)
depwarn("AdaMax(params) is deprecated; use AdaMax(η::Float64) instead", :AdaMax)
ps = params
@ -80,7 +80,7 @@ function AdaMax(params::Params, η = 0.001; β1 = 0.9, β2 = 0.999, decay = 0.)
updaterule(opt, ps)
end
function AMSGrad(params::Params, η = 0.001; β1 = 0.9, β2 = 0.999, decay = 0.)
function AMSGrad(params::Union{AbstractArray, Params}, η = 0.001; β1 = 0.9, β2 = 0.999, decay = 0.)
depwarn("AMSGrad(params) is deprecated; use AMSGrad(η::Float64) instead", :AMSGrad)
ps = params
@ -90,7 +90,7 @@ function AMSGrad(params::Params, η = 0.001; β1 = 0.9, β2 = 0.999, decay = 0.)
updaterule(opt, ps)
end
function NADAM(params::Params, η = 0.001; β1 = 0.9, β2 = 0.999, decay = 0.)
function NADAM(params::Union{AbstractArray, Params}, η = 0.001; β1 = 0.9, β2 = 0.999, decay = 0.)
depwarn("NADAM(params) is deprecated; use NADAM(η::Float64) instead", :NADAM)
ps = params
@ -100,7 +100,7 @@ function NADAM(params::Params, η = 0.001; β1 = 0.9, β2 = 0.999, decay = 0.)
updaterule(opt, ps)
end
function ADAMW(params::Params, η = 0.001; β1 = 0.9, β2 = 0.999, decay = 0.)
function ADAMW(params::Union{AbstractArray, Params}, η = 0.001; β1 = 0.9, β2 = 0.999, decay = 0.)
depwarn("ADAMW(params) is deprecated; use ADAMW(η::Float64) instead", :ADAMW)
ps = params