From 58a6c3f225334698603d9a0f8c1dd7bd9bb5898e Mon Sep 17 00:00:00 2001 From: Dhairya Gandhi Date: Thu, 1 Nov 2018 15:02:00 +0530 Subject: [PATCH 1/3] fix deprecations --- src/optimise/deprecations.jl | 27 ++++++++++++++------------- 1 file changed, 14 insertions(+), 13 deletions(-) diff --git a/src/optimise/deprecations.jl b/src/optimise/deprecations.jl index 40c695b6..247c7a40 100644 --- a/src/optimise/deprecations.jl +++ b/src/optimise/deprecations.jl @@ -1,11 +1,12 @@ using Base: depwarn +using Flux: Params check_decay(opt, decay) = decay == 0 ? opt : Optimiser(opt, InvDecay(decay)) # legacy update rule -updaterule(opt, ps) = () -> update!(p, ps) +updaterule(opt, ps) = () -> update!(opt, ps) -function SGD(params::AbstractArray, η = 0.1; decay = 0.) +function SGD(params::Params, η = 0.1; decay = 0.) depwarn("SGD(params) is deprecated; use Descent(η::Float64) instead", :SGD) ps = params @@ -14,7 +15,7 @@ function SGD(params::AbstractArray, η = 0.1; decay = 0.) updaterule(opt, ps) end -function Momentum(params::AbstractArray, η = 0.01; ρ = 0.9, decay = 0.) +function Momentum(params::Params, η = 0.01; ρ = 0.9, decay = 0.) depwarn("Momentum(params) is deprecated; use Momentum(η::Float64) instead", :Momentum) ps = params @@ -23,7 +24,7 @@ function Momentum(params::AbstractArray, η = 0.01; ρ = 0.9, decay = 0.) updaterule(opt, ps) end -function Nesterov(params::AbstractArray, η = 0.001; ρ = 0.9, decay = 0.) +function Nesterov(params::Params, η = 0.001; ρ = 0.9, decay = 0.) depwarn("Nesterov(params) is deprecated; use Nesterov(η::Float64) instead", :Nesterov) ps = params @@ -32,7 +33,7 @@ function Nesterov(params::AbstractArray, η = 0.001; ρ = 0.9, decay = 0.) updaterule(opt, ps) end -function RMSProp(params::AbstractArray, η = 0.001; ρ = 0.9, decay = 0.) +function RMSProp(params::Params, η = 0.001; ρ = 0.9, decay = 0.) depwarn("RMSProp(params) is deprecated; use RMSProp(η::Float64) instead", :RMSProp) ps = params @@ -41,7 +42,7 @@ function RMSProp(params::AbstractArray, η = 0.001; ρ = 0.9, decay = 0.) updaterule(opt, ps) end -function ADAM(params::AbstractArray, η = 0.001; β1 = 0.9, β2 = 0.999, decay = 0.) +function ADAM(params::Params, η = 0.001; β1 = 0.9, β2 = 0.999, decay = 0.) depwarn("ADAM(params) is deprecated; use ADAM(η::Float64) instead", :ADAM) ps = params @@ -51,7 +52,7 @@ function ADAM(params::AbstractArray, η = 0.001; β1 = 0.9, β2 = 0.999, decay = updaterule(opt, ps) end -function ADAGrad(params::AbstractArray, η::Float64 = 0.1; decay = 0.) +function ADAGrad(params::Params, η::Float64 = 0.1; decay = 0.) depwarn("ADAGrad(params) is deprecated; use ADAGrad(η::Float64) instead", :ADAGrad) ps = params @@ -60,7 +61,7 @@ function ADAGrad(params::AbstractArray, η::Float64 = 0.1; decay = 0.) updaterule(opt, ps) end -function ADADelta(params::AbstractArray, ρ::Float64 = 0.9; decay = 0.) +function ADADelta(params::Params, ρ::Float64 = 0.9; decay = 0.) depwarn("ADADelta(params) is deprecated; use ADADelta(η::Float64) instead", :ADADelta) ps = params @@ -69,7 +70,7 @@ function ADADelta(params::AbstractArray, ρ::Float64 = 0.9; decay = 0.) updaterule(opt, ps) end -function AdaMax(params::AbstractArray, η = 0.001; β1 = 0.9, β2 = 0.999, decay = 0.) +function AdaMax(params::Params, η = 0.001; β1 = 0.9, β2 = 0.999, decay = 0.) depwarn("AdaMax(params) is deprecated; use AdaMax(η::Float64) instead", :AdaMax) ps = params @@ -79,7 +80,7 @@ function AdaMax(params::AbstractArray, η = 0.001; β1 = 0.9, β2 = 0.999, decay updaterule(opt, ps) end -function AMSGrad(params::AbstractArray, η = 0.001; β1 = 0.9, β2 = 0.999, decay = 0.) +function AMSGrad(params::Params, η = 0.001; β1 = 0.9, β2 = 0.999, decay = 0.) depwarn("AMSGrad(params) is deprecated; use AMSGrad(η::Float64) instead", :AMSGrad) ps = params @@ -89,7 +90,7 @@ function AMSGrad(params::AbstractArray, η = 0.001; β1 = 0.9, β2 = 0.999, deca updaterule(opt, ps) end -function NADAM(params::AbstractArray, η = 0.001; β1 = 0.9, β2 = 0.999, decay = 0.) +function NADAM(params::Params, η = 0.001; β1 = 0.9, β2 = 0.999, decay = 0.) depwarn("NADAM(params) is deprecated; use NADAM(η::Float64) instead", :NADAM) ps = params @@ -99,14 +100,14 @@ function NADAM(params::AbstractArray, η = 0.001; β1 = 0.9, β2 = 0.999, decay updaterule(opt, ps) end -function ADAMW(params::AbstractArray, η = 0.001; β1 = 0.9, β2 = 0.999, decay = 0.) +function ADAMW(params::Params, η = 0.001; β1 = 0.9, β2 = 0.999, decay = 0.) depwarn("ADAMW(params) is deprecated; use ADAMW(η::Float64) instead", :ADAMW) ps = params β = (β1, β2) opt = ADAMW(η, β) opt = check_decay(opt, decay) - decay != 0 && (opt = Optimiser(opt, WeightDecay(decay))) + decay != 0 && (opt = Optimiser(opt, WeightDecay(η * decay))) updaterule(opt, ps) end From ca4e01ac262609758805e07c96a5b62c800a7e05 Mon Sep 17 00:00:00 2001 From: Dhairya Gandhi Date: Thu, 1 Nov 2018 15:58:40 +0530 Subject: [PATCH 2/3] use user defined decay in ADAMW --- src/optimise/deprecations.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/optimise/deprecations.jl b/src/optimise/deprecations.jl index 247c7a40..a90a6a79 100644 --- a/src/optimise/deprecations.jl +++ b/src/optimise/deprecations.jl @@ -107,7 +107,7 @@ function ADAMW(params::Params, η = 0.001; β1 = 0.9, β2 = 0.999, decay = 0.) β = (β1, β2) opt = ADAMW(η, β) opt = check_decay(opt, decay) - decay != 0 && (opt = Optimiser(opt, WeightDecay(η * decay))) + decay != 0 && (opt = Optimiser(opt, WeightDecay(decay))) updaterule(opt, ps) end From 5ec70fe29d28e2c6a08791fd9efa95a06f946a99 Mon Sep 17 00:00:00 2001 From: Dhairya Gandhi Date: Thu, 1 Nov 2018 22:17:54 +0530 Subject: [PATCH 3/3] allow array parameters to old optimisers --- src/optimise/deprecations.jl | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/src/optimise/deprecations.jl b/src/optimise/deprecations.jl index a90a6a79..34853bf6 100644 --- a/src/optimise/deprecations.jl +++ b/src/optimise/deprecations.jl @@ -6,7 +6,7 @@ check_decay(opt, decay) = decay == 0 ? opt : Optimiser(opt, InvDecay(decay)) # legacy update rule updaterule(opt, ps) = () -> update!(opt, ps) -function SGD(params::Params, η = 0.1; decay = 0.) +function SGD(params::Union{AbstractArray, Params}, η = 0.1; decay = 0.) depwarn("SGD(params) is deprecated; use Descent(η::Float64) instead", :SGD) ps = params @@ -15,7 +15,7 @@ function SGD(params::Params, η = 0.1; decay = 0.) updaterule(opt, ps) end -function Momentum(params::Params, η = 0.01; ρ = 0.9, decay = 0.) +function Momentum(params::Union{AbstractArray, Params}, η = 0.01; ρ = 0.9, decay = 0.) depwarn("Momentum(params) is deprecated; use Momentum(η::Float64) instead", :Momentum) ps = params @@ -24,7 +24,7 @@ function Momentum(params::Params, η = 0.01; ρ = 0.9, decay = 0.) updaterule(opt, ps) end -function Nesterov(params::Params, η = 0.001; ρ = 0.9, decay = 0.) +function Nesterov(params::Union{AbstractArray, Params}, η = 0.001; ρ = 0.9, decay = 0.) depwarn("Nesterov(params) is deprecated; use Nesterov(η::Float64) instead", :Nesterov) ps = params @@ -33,7 +33,7 @@ function Nesterov(params::Params, η = 0.001; ρ = 0.9, decay = 0.) updaterule(opt, ps) end -function RMSProp(params::Params, η = 0.001; ρ = 0.9, decay = 0.) +function RMSProp(params::Union{AbstractArray, Params}, η = 0.001; ρ = 0.9, decay = 0.) depwarn("RMSProp(params) is deprecated; use RMSProp(η::Float64) instead", :RMSProp) ps = params @@ -42,7 +42,7 @@ function RMSProp(params::Params, η = 0.001; ρ = 0.9, decay = 0.) updaterule(opt, ps) end -function ADAM(params::Params, η = 0.001; β1 = 0.9, β2 = 0.999, decay = 0.) +function ADAM(params::Union{AbstractArray, Params}, η = 0.001; β1 = 0.9, β2 = 0.999, decay = 0.) depwarn("ADAM(params) is deprecated; use ADAM(η::Float64) instead", :ADAM) ps = params @@ -52,7 +52,7 @@ function ADAM(params::Params, η = 0.001; β1 = 0.9, β2 = 0.999, decay = 0.) updaterule(opt, ps) end -function ADAGrad(params::Params, η::Float64 = 0.1; decay = 0.) +function ADAGrad(params::Union{AbstractArray, Params}, η::Float64 = 0.1; decay = 0.) depwarn("ADAGrad(params) is deprecated; use ADAGrad(η::Float64) instead", :ADAGrad) ps = params @@ -61,7 +61,7 @@ function ADAGrad(params::Params, η::Float64 = 0.1; decay = 0.) updaterule(opt, ps) end -function ADADelta(params::Params, ρ::Float64 = 0.9; decay = 0.) +function ADADelta(params::Union{AbstractArray, Params}, ρ::Float64 = 0.9; decay = 0.) depwarn("ADADelta(params) is deprecated; use ADADelta(η::Float64) instead", :ADADelta) ps = params @@ -70,7 +70,7 @@ function ADADelta(params::Params, ρ::Float64 = 0.9; decay = 0.) updaterule(opt, ps) end -function AdaMax(params::Params, η = 0.001; β1 = 0.9, β2 = 0.999, decay = 0.) +function AdaMax(params::Union{AbstractArray, Params}, η = 0.001; β1 = 0.9, β2 = 0.999, decay = 0.) depwarn("AdaMax(params) is deprecated; use AdaMax(η::Float64) instead", :AdaMax) ps = params @@ -80,7 +80,7 @@ function AdaMax(params::Params, η = 0.001; β1 = 0.9, β2 = 0.999, decay = 0.) updaterule(opt, ps) end -function AMSGrad(params::Params, η = 0.001; β1 = 0.9, β2 = 0.999, decay = 0.) +function AMSGrad(params::Union{AbstractArray, Params}, η = 0.001; β1 = 0.9, β2 = 0.999, decay = 0.) depwarn("AMSGrad(params) is deprecated; use AMSGrad(η::Float64) instead", :AMSGrad) ps = params @@ -90,7 +90,7 @@ function AMSGrad(params::Params, η = 0.001; β1 = 0.9, β2 = 0.999, decay = 0.) updaterule(opt, ps) end -function NADAM(params::Params, η = 0.001; β1 = 0.9, β2 = 0.999, decay = 0.) +function NADAM(params::Union{AbstractArray, Params}, η = 0.001; β1 = 0.9, β2 = 0.999, decay = 0.) depwarn("NADAM(params) is deprecated; use NADAM(η::Float64) instead", :NADAM) ps = params @@ -100,7 +100,7 @@ function NADAM(params::Params, η = 0.001; β1 = 0.9, β2 = 0.999, decay = 0.) updaterule(opt, ps) end -function ADAMW(params::Params, η = 0.001; β1 = 0.9, β2 = 0.999, decay = 0.) +function ADAMW(params::Union{AbstractArray, Params}, η = 0.001; β1 = 0.9, β2 = 0.999, decay = 0.) depwarn("ADAMW(params) is deprecated; use ADAMW(η::Float64) instead", :ADAMW) ps = params