Merge pull request #479 from dhairyagandhi96/master

Fix deprecations of optimisers
This commit is contained in:
Mike J Innes 2018-11-05 13:01:59 +00:00 committed by GitHub
commit 8042198475
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -1,11 +1,12 @@
using Base: depwarn using Base: depwarn
using Flux: Params
check_decay(opt, decay) = decay == 0 ? opt : Optimiser(opt, InvDecay(decay)) check_decay(opt, decay) = decay == 0 ? opt : Optimiser(opt, InvDecay(decay))
# legacy update rule # legacy update rule
updaterule(opt, ps) = () -> update!(p, ps) updaterule(opt, ps) = () -> update!(opt, ps)
function SGD(params::AbstractArray, η = 0.1; decay = 0.) function SGD(params::Union{AbstractArray, Params}, η = 0.1; decay = 0.)
depwarn("SGD(params) is deprecated; use Descent(η::Float64) instead", :SGD) depwarn("SGD(params) is deprecated; use Descent(η::Float64) instead", :SGD)
ps = params ps = params
@ -14,7 +15,7 @@ function SGD(params::AbstractArray, η = 0.1; decay = 0.)
updaterule(opt, ps) updaterule(opt, ps)
end end
function Momentum(params::AbstractArray, η = 0.01; ρ = 0.9, decay = 0.) function Momentum(params::Union{AbstractArray, Params}, η = 0.01; ρ = 0.9, decay = 0.)
depwarn("Momentum(params) is deprecated; use Momentum(η::Float64) instead", :Momentum) depwarn("Momentum(params) is deprecated; use Momentum(η::Float64) instead", :Momentum)
ps = params ps = params
@ -23,7 +24,7 @@ function Momentum(params::AbstractArray, η = 0.01; ρ = 0.9, decay = 0.)
updaterule(opt, ps) updaterule(opt, ps)
end end
function Nesterov(params::AbstractArray, η = 0.001; ρ = 0.9, decay = 0.) function Nesterov(params::Union{AbstractArray, Params}, η = 0.001; ρ = 0.9, decay = 0.)
depwarn("Nesterov(params) is deprecated; use Nesterov(η::Float64) instead", :Nesterov) depwarn("Nesterov(params) is deprecated; use Nesterov(η::Float64) instead", :Nesterov)
ps = params ps = params
@ -32,7 +33,7 @@ function Nesterov(params::AbstractArray, η = 0.001; ρ = 0.9, decay = 0.)
updaterule(opt, ps) updaterule(opt, ps)
end end
function RMSProp(params::AbstractArray, η = 0.001; ρ = 0.9, decay = 0.) function RMSProp(params::Union{AbstractArray, Params}, η = 0.001; ρ = 0.9, decay = 0.)
depwarn("RMSProp(params) is deprecated; use RMSProp(η::Float64) instead", :RMSProp) depwarn("RMSProp(params) is deprecated; use RMSProp(η::Float64) instead", :RMSProp)
ps = params ps = params
@ -41,7 +42,7 @@ function RMSProp(params::AbstractArray, η = 0.001; ρ = 0.9, decay = 0.)
updaterule(opt, ps) updaterule(opt, ps)
end end
function ADAM(params::AbstractArray, η = 0.001; β1 = 0.9, β2 = 0.999, decay = 0.) function ADAM(params::Union{AbstractArray, Params}, η = 0.001; β1 = 0.9, β2 = 0.999, decay = 0.)
depwarn("ADAM(params) is deprecated; use ADAM(η::Float64) instead", :ADAM) depwarn("ADAM(params) is deprecated; use ADAM(η::Float64) instead", :ADAM)
ps = params ps = params
@ -51,7 +52,7 @@ function ADAM(params::AbstractArray, η = 0.001; β1 = 0.9, β2 = 0.999, decay =
updaterule(opt, ps) updaterule(opt, ps)
end end
function ADAGrad(params::AbstractArray, η::Float64 = 0.1; decay = 0.) function ADAGrad(params::Union{AbstractArray, Params}, η::Float64 = 0.1; decay = 0.)
depwarn("ADAGrad(params) is deprecated; use ADAGrad(η::Float64) instead", :ADAGrad) depwarn("ADAGrad(params) is deprecated; use ADAGrad(η::Float64) instead", :ADAGrad)
ps = params ps = params
@ -60,7 +61,7 @@ function ADAGrad(params::AbstractArray, η::Float64 = 0.1; decay = 0.)
updaterule(opt, ps) updaterule(opt, ps)
end end
function ADADelta(params::AbstractArray, ρ::Float64 = 0.9; decay = 0.) function ADADelta(params::Union{AbstractArray, Params}, ρ::Float64 = 0.9; decay = 0.)
depwarn("ADADelta(params) is deprecated; use ADADelta(η::Float64) instead", :ADADelta) depwarn("ADADelta(params) is deprecated; use ADADelta(η::Float64) instead", :ADADelta)
ps = params ps = params
@ -69,7 +70,7 @@ function ADADelta(params::AbstractArray, ρ::Float64 = 0.9; decay = 0.)
updaterule(opt, ps) updaterule(opt, ps)
end end
function AdaMax(params::AbstractArray, η = 0.001; β1 = 0.9, β2 = 0.999, decay = 0.) function AdaMax(params::Union{AbstractArray, Params}, η = 0.001; β1 = 0.9, β2 = 0.999, decay = 0.)
depwarn("AdaMax(params) is deprecated; use AdaMax(η::Float64) instead", :AdaMax) depwarn("AdaMax(params) is deprecated; use AdaMax(η::Float64) instead", :AdaMax)
ps = params ps = params
@ -79,7 +80,7 @@ function AdaMax(params::AbstractArray, η = 0.001; β1 = 0.9, β2 = 0.999, decay
updaterule(opt, ps) updaterule(opt, ps)
end end
function AMSGrad(params::AbstractArray, η = 0.001; β1 = 0.9, β2 = 0.999, decay = 0.) function AMSGrad(params::Union{AbstractArray, Params}, η = 0.001; β1 = 0.9, β2 = 0.999, decay = 0.)
depwarn("AMSGrad(params) is deprecated; use AMSGrad(η::Float64) instead", :AMSGrad) depwarn("AMSGrad(params) is deprecated; use AMSGrad(η::Float64) instead", :AMSGrad)
ps = params ps = params
@ -89,7 +90,7 @@ function AMSGrad(params::AbstractArray, η = 0.001; β1 = 0.9, β2 = 0.999, deca
updaterule(opt, ps) updaterule(opt, ps)
end end
function NADAM(params::AbstractArray, η = 0.001; β1 = 0.9, β2 = 0.999, decay = 0.) function NADAM(params::Union{AbstractArray, Params}, η = 0.001; β1 = 0.9, β2 = 0.999, decay = 0.)
depwarn("NADAM(params) is deprecated; use NADAM(η::Float64) instead", :NADAM) depwarn("NADAM(params) is deprecated; use NADAM(η::Float64) instead", :NADAM)
ps = params ps = params
@ -99,7 +100,7 @@ function NADAM(params::AbstractArray, η = 0.001; β1 = 0.9, β2 = 0.999, decay
updaterule(opt, ps) updaterule(opt, ps)
end end
function ADAMW(params::AbstractArray, η = 0.001; β1 = 0.9, β2 = 0.999, decay = 0.) function ADAMW(params::Union{AbstractArray, Params}, η = 0.001; β1 = 0.9, β2 = 0.999, decay = 0.)
depwarn("ADAMW(params) is deprecated; use ADAMW(η::Float64) instead", :ADAMW) depwarn("ADAMW(params) is deprecated; use ADAMW(η::Float64) instead", :ADAMW)
ps = params ps = params