From a2d2d068aa0b60c228b0552de29981c273818ce1 Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Thu, 31 May 2018 20:29:59 +0100 Subject: [PATCH 01/23] initial sketch --- src/Flux.jl | 4 +- src/optimise/Optimise.jl | 18 +-- src/optimise/interface.jl | 110 ------------------ src/optimise/optimisers.jl | 223 +++++++++++++++++++------------------ src/optimise/train.jl | 13 ++- src/tracker/Tracker.jl | 87 +++------------ 6 files changed, 148 insertions(+), 307 deletions(-) delete mode 100644 src/optimise/interface.jl diff --git a/src/Flux.jl b/src/Flux.jl index 614eeaf7..6ec849b0 100644 --- a/src/Flux.jl +++ b/src/Flux.jl @@ -19,8 +19,8 @@ export Tracker, TrackedArray, TrackedVector, TrackedMatrix, param include("optimise/Optimise.jl") using .Optimise using .Optimise: @epochs -export SGD, ADAM, ADAMW, AdaMax, Momentum, Nesterov, - RMSProp, ADAGrad, ADADelta, AMSGrad, NADAM +export SGD, Descent, ADAM, AdaMax, Momentum, Nesterov, + RMSProp, ADAGrad, ADADelta, AMSGrad include("utils.jl") include("onehot.jl") diff --git a/src/optimise/Optimise.jl b/src/optimise/Optimise.jl index c4828c9e..d54e4453 100644 --- a/src/optimise/Optimise.jl +++ b/src/optimise/Optimise.jl @@ -1,23 +1,9 @@ module Optimise export train!, - SGD, ADAM, ADAMW, AdaMax, Momentum, Nesterov, - RMSProp, ADAGrad, ADADelta, AMSGrad, NADAM, stop, StopException - -struct Param{T} - x::T - Δ::T -end - -Param(x::AbstractArray) = Param(x, zero(x)) + SGD, Descent, ADAM, AdaMax, Momentum, Nesterov, RMSProp, ADAGrad, ADADelta, AMSGrad include("optimisers.jl") -include("interface.jl") include("train.jl") -using Flux.Tracker: TrackedArray - -Param(x::TrackedArray) = Param(x.data, x.grad) -# Base.convert(::Type{Param}, x::TrackedArray) = Param(x.data, x.grad) - -end +end \ No newline at end of file diff --git a/src/optimise/interface.jl b/src/optimise/interface.jl deleted file mode 100644 index 096e2d87..00000000 --- a/src/optimise/interface.jl +++ /dev/null @@ -1,110 +0,0 @@ -call(f, xs...) = f(xs...) - -# note for optimisers: set to zero -# p.Δ at the end of the weights update -function optimiser(ps, fs...) - ps = [Param(p) for p in ps] - fs = map(ps) do p - os = map(f -> f(p), fs) - () -> foreach(call, os) - end - () -> foreach(call, fs) -end - -""" - SGD(params, η = 0.1; decay = 0) - -Classic gradient descent optimiser with learning rate `η`. -For each parameter `p` and its gradient `δp`, this runs `p -= η*δp`. - -Supports inverse decaying learning rate if the `decay` argument is provided. -""" -SGD(ps, η = 0.1; decay = 0) = - optimiser(ps, p -> invdecay(p, decay), p -> descent(p,η)) - -""" - Momentum(params, η = 0.01; ρ = 0.9, decay = 0) - -SGD with learning rate `η`, momentum `ρ` and optional learning rate inverse decay. -""" -Momentum(ps, η = 0.01; ρ = 0.9, decay = 0) = - optimiser(ps, p->invdecay(p,decay), p->momentum(p, ρ, η), p->descent(p,1)) - -""" - Nesterov(params, η = 0.01; ρ = 0.9, decay = 0) - -SGD with learning rate `η`, Nesterov momentum `ρ` and optional learning rate inverse decay. -""" -Nesterov(ps, η = 0.01; ρ = 0.9, decay = 0) = - optimiser(ps, p->invdecay(p,decay), p->nesterov(p, ρ, η), p->descent(p,1)) - -""" - RMSProp(params, η = 0.001; ρ = 0.9, ϵ = 1e-8, decay = 0) - -[RMSProp](http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf) -optimiser. Parameters other than learning rate don't need tuning. Often a good -choice for recurrent networks. 
-""" -RMSProp(ps, η = 0.001; ρ = 0.9, ϵ = 1e-8, decay = 0) = - optimiser(ps, p->rmsprop(p; η=η, ρ=ρ, ϵ=ϵ), p->invdecay(p,decay), p->descent(p,1)) - -""" - ADAM(params, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0) - -[ADAM](https://arxiv.org/abs/1412.6980v8) optimiser. -""" -ADAM(ps, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0) = - optimiser(ps, p->adam(p; η=η, β1=β1, β2=β2, ϵ=ϵ), p->invdecay(p,decay), p->descent(p,1)) - -""" - ADAMW((params, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0) - -[ADAMW](https://arxiv.org/abs/1711.05101) fixing weight decay regularization in Adam. -""" -ADAMW(ps, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0) = - optimiser(ps, p->adam(p; η=η, β1=β1, β2=β2, ϵ=ϵ), p->descentweightdecay(p,1,decay)) - -""" - AdaMax(params, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0) - -[AdaMax](https://arxiv.org/abs/1412.6980v9) optimiser. Variant of ADAM based on -the ∞-norm. -""" -AdaMax(ps, η = 0.002; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0) = - optimiser(ps, p->adamax(p; η=η, β1=β1, β2=β2, ϵ=ϵ), p->invdecay(p,decay), p->descent(p,1)) - -""" - ADAGrad(params, η = 0.01; ϵ = 1e-8, decay = 0) - -[ADAGrad](http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf) optimiser. -Parameters don't need tuning. -""" -ADAGrad(ps, η = 0.01; ϵ = 1e-8, decay = 0) = - optimiser(ps, p->adagrad(p; η=η, ϵ=ϵ), p->invdecay(p,decay), p->descent(p,1)) - -""" - ADADelta(params; ρ = 0.9, ϵ = 1e-8, decay = 0) - -[ADADelta](http://arxiv.org/abs/1212.5701) optimiser. Parameters don't need -tuning. -""" -ADADelta(ps; ρ = 0.9, ϵ = 1e-8, decay = 0) = - optimiser(ps, p->adadelta(p; ρ=ρ, ϵ=ϵ), p->descent(p,1)) - -""" - AMSGrad(params; η = 0.001, β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0) - -[AMSGrad](https://openreview.net/forum?id=ryQu7f-RZ) optimiser. Parameters don't need -tuning. -""" -AMSGrad(ps, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0) = - optimiser(ps, p -> amsgrad(p; η = η, β1 = β1, β2 = β2, ϵ = ϵ), p -> invdecay(p, decay), p -> descent(p, 1)) - -""" - NADAM(params, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0) - -[NADAM](https://openreview.net/pdf?id=OM0jvwB8jIp57ZJjtNEZ) optimiser. Parameters other -than learning rate don't need tuning. -""" -NADAM(ps, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0) = - optimiser(ps, p->nadam(p; η=η, β1=β1, β2=β2, ϵ=ϵ), p->invdecay(p,decay), p->descent(p,1)) diff --git a/src/optimise/optimisers.jl b/src/optimise/optimisers.jl index 1f7a7c9c..cfbbcfe9 100644 --- a/src/optimise/optimisers.jl +++ b/src/optimise/optimisers.jl @@ -1,130 +1,139 @@ -function descent(p::Param, η::Real) - function () - @. p.x -= η * p.Δ - @. p.Δ = 0 - end +using Flux +using Base: @get! + +const ϵ = 1e-8 + +# TODO: should use weak refs + +""" + Descent(η) + +Classic gradient descent optimiser with learning rate `η`. +For each parameter `p` and its gradient `δp`, this runs `p -= η*δp`. +""" +mutable struct Descent + eta::Float64 end -# Ref: https://arxiv.org/abs/1711.05101.pdf -function descentweightdecay(p::Param, η::Real, γ::Real) - function () - @. p.x = p.x - η * (p.Δ + γ * p.x) - @. p.Δ = 0 - end +function update!(o::Descent, x, Δ) + Δ .*= o.eta end -function momentum(p::Param, ρ, η) - v = zero(p.x) - function () - @. v = ρ * v - η * p.Δ - @. p.Δ = -v - end +""" + Momentum(params, η = 0.01; ρ = 0.9, decay = 0) + +Gradient descent with learning rate `η` and momentum `ρ`. +""" +mutable struct Momentum + eta::Float64 + rho::Float64 + velocity::ObjectIdDict end -# Ref. 
https://arxiv.org/pdf/1212.0901.pdf -function nesterov(p::Param, ρ, η) - v = zero(p.x) - function () - d = @. ρ^2 * v - (1+ρ) * η * p.Δ - @. v = ρ*v - η*p.Δ - @. p.Δ = -d - end +Momentum(η, ρ = 0.9) = Momentum(η, ρ, ObjectIdDict()) + +function update!(o::Momentum, x, Δ) + η, ρ = o.eta, o.rho + v = @get!(o.velocity, x, zero(x))::typeof(x) + @. v = ρ * v - η * Δ + @. Δ = -v end -function rmsprop(p::Param; η::Real = 0.001, ρ::Real = 0.9, ϵ::Real = 1e-8) - acc = zero(p.x) - function () - @. acc = ρ * acc + (1 - ρ) * p.Δ^2 - @. p.Δ *= η / √(acc + ϵ) - end +""" + Nesterov(eta, ρ = 0.9) + +Gradient descent with learning rate `η` and Nesterov momentum `ρ`. +""" +mutable struct Nesterov + eta::Float64 + rho::Float64 + velocity::ObjectIdDict end -function adagrad(p::Param; η::Real = 0.01, ϵ::Real = 1e-8) - acc = zero(p.x) .+ ϵ - function () - @. acc += p.Δ^2 - @. p.Δ *= η / √(acc + ϵ) - end +Nesterov(η, ρ = 0.9) = Nesterov(η, ρ, ObjectIdDict()) + +function update!(o::Nesterov, x, Δ) + η, ρ = o.eta, o.rho + v = @get!(o.velocity, x, zero(x))::typeof(x) + d = @. ρ^2 * v - (1+ρ) * η * Δ + @. v = ρ*v - η*Δ + @. Δ = -d end -function adadelta(p::Param; ρ::Real = 0.9, ϵ::Real = 1e-8) - acc = zero(p.x) - Δacc = zero(p.x) - function () - @. acc = ρ * acc + (1 - ρ) * p.Δ^2 - @. p.Δ *= √(Δacc + ϵ) / √(acc + ϵ) - @. Δacc = ρ * Δacc + (1 - ρ) * p.Δ^2 - end +""" + RMSProp(η = 0.001, ρ = 0.9) + +[RMSProp](http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf) +optimiser. Parameters other than learning rate don't need tuning. Often a good +choice for recurrent networks. +""" +mutable struct RMSProp + eta::Float64 + rho::Float64 + acc::ObjectIdDict end -function adam(p::Param; η::Real = 0.001, β1::Real = 0.9, β2::Real = 0.999, ϵ::Real = 1e-8) - mt = zero(p.x) - vt = zero(p.x) - β1p, β2p = β1, β2 - function () - @. mt = β1 * mt + (1 - β1) * p.Δ - @. vt = β2 * vt + (1 - β2) * p.Δ^2 - @. p.Δ = mt / (1 - β1p) / √(vt / (1 - β2p) + ϵ) * η - β1p *= β1 - β2p *= β2 - end +RMSProp(η = 0.001, ρ = 0.9) = RMSProp(η, ρ, ObjectIdDict()) + +function update!(o::RMSProp, x, Δ) + η, ρ = o.eta, o.rho + acc = @get!(o.acc, x, zero(x))::typeof(x) + @. acc = ρ * acc + (1 - ρ) * Δ^2 + @. Δ *= η / (√acc + ϵ) end -function adamax(p::Param; η::Real = 0.002, β1::Real = 0.9, β2::Real = 0.999, ϵ::Real = 1e-8) - mt = zero(p.x) - ut = zero(p.x) - β1p = β1 - function () - @. mt = β1 * mt + (1 - β1) * p.Δ - @. ut = max(β2 * ut, abs(p.Δ)) - @. p.Δ = (η/(1 - β1p)) * mt/(ut + ϵ) - β1p *= β1 - end +""" + ADAM(params, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0) + +[ADAM](https://arxiv.org/abs/1412.6980v8) optimiser. +""" +mutable struct ADAM + eta::Float64 + beta::Tuple{Float64,Float64} + state::ObjectIdDict end -function amsgrad(p::Param; η::Real = 0.001, β1::Real = 0.9, β2::Real = 0.999, ϵ::Real = 1e-8) - mt = zero(p.x) - vt = zero(p.x) .+ ϵ - v̂t = zero(p.x) .+ ϵ - function () - @. mt = β1 * mt + (1 - β1) * p.Δ - @. vt = β2 * vt + (1 - β2) * p.Δ ^ 2 - @. v̂t = max.(v̂t, vt) - @. p.Δ = η * mt / √v̂t - end +ADAM(η = 0.001, β = (0.9, 0.999)) = ADAM(η, β, ObjectIdDict()) + +function update!(o::ADAM, x, Δ) + η, β = o.eta, o.beta + mt, vt, βp = @get!(o.state, x, (zero(x), zero(x), β)) + @. mt = β[1] * mt + (1 - β[1]) * Δ + @. vt = β[2] * vt + (1 - β[2]) * Δ^2 + @. Δ = mt / (1 - βp[1]) / (√(vt / (1 - βp[2])) + ϵ) * η + o.state[x] = (mt, vt, βp .* β) end -function nadam(p::Param; η::Real = 0.001, β1::Real = 0.9, β2::Real = 0.999, ϵ::Real = 1e-8) - mt = zero(p.x) - vt = zero(p.x) - β1p, β2p = β1, β2 - function () - @. 
mt = β1 * mt + (1 - β1) * p.Δ - @. vt = β2 * vt + (1 - β2) * p.Δ^2 - @. p.Δ = (β1 * mt / (1 - β1 * β1p) + (1 - β1) * p.Δ / (1 - β1p)) / √(vt * β2 / (1 - β2p) + ϵ) * η - β1p *= β1 - β2p *= β2 - end -end +# """ +# AdaMax(params, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0) +# +# [AdaMax](https://arxiv.org/abs/1412.6980v9) optimiser. Variant of ADAM based on +# the ∞-norm. +# """ -clip(p::Param, thresh::Real) = () -> clamp!(p.Δ, -thresh, thresh) +# """ +# ADAGrad(params, η = 0.01; ϵ = 1e-8, decay = 0) +# +# [ADAGrad](http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf) optimiser. +# Parameters don't need tuning. +# """ -function expdecay(p::Param, γ::Real) - if γ != 0 - return () -> p.Δ .+= γ .* p.x - else - return () -> nothing - end -end +# """ +# ADADelta(params; ρ = 0.9, ϵ = 1e-8, decay = 0) +# +# [ADADelta](http://arxiv.org/abs/1212.5701) optimiser. Parameters don't need +# tuning. +# """ -function invdecay(p::Param, γ::Real) - if γ != 0 - n = 0 - return () -> begin - p.Δ .*= 1 / (1 + γ * n) - n += 1 - end - else - return () -> nothing - end -end +# """ +# AMSGrad(params; η = 0.001, β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0) +# +# [AMSGrad](https://openreview.net/forum?id=ryQu7f-RZ) optimiser. Parameters don't need +# tuning. +# """ + +# struct Optimiser +# os::Vector{Any} +# end + +# TODO: decay diff --git a/src/optimise/train.jl b/src/optimise/train.jl index 09893873..85c402e6 100644 --- a/src/optimise/train.jl +++ b/src/optimise/train.jl @@ -1,7 +1,16 @@ using Juno -using Flux.Tracker: back! -import Base.depwarn +using Flux.Tracker: data, grad, back! +function update!(opt, xs) + for x in xs + x, Δ = data(x), grad(x) + update!(opt, x, Δ) + x .-= Δ + Δ .= 0 + end +end + +# Callback niceties runall(f) = f runall(fs::AbstractVector) = () -> foreach(call, fs) diff --git a/src/tracker/Tracker.jl b/src/tracker/Tracker.jl index 190837ab..036c0904 100644 --- a/src/tracker/Tracker.jl +++ b/src/tracker/Tracker.jl @@ -1,27 +1,23 @@ module Tracker -using MacroTools -using MacroTools: @q, @forward - import Base: == -export TrackedArray, TrackedVector, TrackedMatrix, Params, param, back! +export TrackedArray, TrackedVector, TrackedMatrix, param, back! tracker(x) = nothing istracked(x) = tracker(x) ≠ nothing isleaf(x) = !istracked(x) || isleaf(tracker(x)) +data(x) = istracked(x) ? data(tracker(x)) : x grad(x) = grad(tracker(x)) grad(::Nothing) = nothing -data(x) = x struct Call{F,As<:Tuple} func::F args::As end -Call(f::F, args::T) where {F,T} = Call{F,T}(f, args) -Call() = Call(nothing, ()) +Call(f, args...) = Call{typeof(f),typeof(args)}(f, args) # When deserialising, the object_id changes a::Call == b::Call = a.func == b.func && a.args == b.args @@ -32,86 +28,37 @@ mutable struct Tracked{T} ref::UInt32 f::Call isleaf::Bool + data::T grad::T - Tracked{T}(f::Call) where T = new(0, f, false) - Tracked{T}(f::Call, grad::T) where T = new(0, f, false, grad) - Tracked{T}(f::Call{Nothing}, grad::T) where T = new(0, f, true, grad) + Tracked{T}(f::Call, data::T) where T = new(0, f, false, data) + Tracked{T}(f::Call, data::T, grad::T) where T = new(0, f, false, data, grad) + Tracked{T}(f::Call{Void}, data::T, grad::T) where T = new(0, f, true, data, grad) end +Tracked(f::Call, x) = Tracked{typeof(x)}(f, x) +Tracked(f::Call, x, Δ) = Tracked{typeof(x)}(f, x, Δ) + +track(f::Call, x) = Tracked(f, x) +track(f::Call) = track(f, f()) +track(f, xs...) 
= track(Call(f, xs...)) + istracked(x::Tracked) = true -isleaf(x::Tracked) = x.f == Call() +isleaf(x::Tracked) = x.f == Call(nothing) +data(x::Tracked) = x.data grad(x::Tracked) = x.grad -track(f::Call, x) = Tracked{typeof(x)}(f) - -function _forward end - -function track(f::F, xs...; kw...) where F - y, back = _forward(f, xs...; kw...) - track(Call(back, tracker.(xs)), y) -end - -macro grad(ex) - @capture(shortdef(ex), (name_(args__) = body_) | - (name_(args__) where {T__} = body_)) || error("Need a function definition") - T == nothing && (T = []) - isexpr(name, :(::)) || (name = :(::typeof($name))) - insert!(args, 1+isexpr(args[1], :parameters) , name) - @q(Tracker._forward($(args...)) where $(T...) = $body) |> esc -end - -function update!(x, Δ) - x.data .+= data(Δ) - tracker(x).grad .= 0 - return x -end - -include("idset.jl") include("back.jl") include("scalar.jl") include("array.jl") include("numeric.jl") -""" - hook(f, x) -> x′ - -Hook into gradient backpropagation. `x` is unmodified, but when backpropagating -`f` will be applied to the incoming gradient. For example, `hook(-, x)` will reverse -the sign of the gradient applied to `x`.""" -hook(f, x) = istracked(x) ? track(hook, f, x) : x -@grad hook(f, x) = data(x), Δ -> (nothing, f(Δ)) - -""" - checkpoint(f, args...) - -Behaves like `f(args...)`, but avoids storing the intermediate values needed for -calculating gradients. Instead, `f(args...)` will be called again during the -backward pass. This can be used to save memory in larger models. -""" -checkpoint(f, args...) = track(checkpoint, f, args...) - -@grad function checkpoint(f, args...) - data(f(args...)), function (Δ) - y, back = forward(f, args...) - (nothing, back(Δ)...) - end -end - -nobacksies(f, x) = track(nobacksies, f, x) -nobacksies(f, xs::Tuple) = map(x -> nobacksies(f, x), xs) -@grad nobacksies(f, x) = data(x), Δ -> error("Nested AD not defined for $f") - param(x::Number) = TrackedReal(float(x)) param(xs::AbstractArray) = TrackedArray(float.(xs)) -@grad identity(x) = data(x), Δ -> (Δ,) -param(x::TrackedReal) = track(identity, x) -param(x::TrackedArray) = track(identity, x) - import NNlib.cudata import Adapt.adapt cudata(x::TrackedArray) = data(x) adapt(T, xs::TrackedArray) = param(adapt(T, data(xs))) -end +end \ No newline at end of file From d933f2079b50865d6c19ffd88cc5823f627b1e92 Mon Sep 17 00:00:00 2001 From: Dhairya Gandhi Date: Tue, 11 Sep 2018 18:30:24 +0530 Subject: [PATCH 02/23] pulled tracker from upstream --- src/Flux.jl | 4 +- src/optimise/Optimise.jl | 2 +- src/optimise/optimisers.jl | 24 +++++------ src/tracker/Tracker.jl | 83 ++++++++++++++++++++++++++++++-------- src/tracker/array.jl | 10 ++--- test/optimise.jl | 24 ++++++----- 6 files changed, 100 insertions(+), 47 deletions(-) diff --git a/src/Flux.jl b/src/Flux.jl index e8ca9d75..e684be56 100644 --- a/src/Flux.jl +++ b/src/Flux.jl @@ -19,8 +19,8 @@ export Tracker, TrackedArray, TrackedVector, TrackedMatrix, param include("optimise/Optimise.jl") using .Optimise using .Optimise: @epochs -export SGD, Descent, ADAM, AdaMax, Momentum, Nesterov, - RMSProp, ADAGrad, ADADelta, AMSGrad +export Descent, ADAM, Momentum, Nesterov, + RMSProp, update! 
include("utils.jl") include("onehot.jl") diff --git a/src/optimise/Optimise.jl b/src/optimise/Optimise.jl index d54e4453..c8abcf3d 100644 --- a/src/optimise/Optimise.jl +++ b/src/optimise/Optimise.jl @@ -1,7 +1,7 @@ module Optimise export train!, - SGD, Descent, ADAM, AdaMax, Momentum, Nesterov, RMSProp, ADAGrad, ADADelta, AMSGrad + Descent, ADAM, Momentum, Nesterov, RMSProp include("optimisers.jl") include("train.jl") diff --git a/src/optimise/optimisers.jl b/src/optimise/optimisers.jl index cfbbcfe9..ce04fe5a 100644 --- a/src/optimise/optimisers.jl +++ b/src/optimise/optimisers.jl @@ -27,14 +27,14 @@ Gradient descent with learning rate `η` and momentum `ρ`. mutable struct Momentum eta::Float64 rho::Float64 - velocity::ObjectIdDict + velocity::IdDict end -Momentum(η, ρ = 0.9) = Momentum(η, ρ, ObjectIdDict()) +Momentum(η, ρ = 0.9) = Momentum(η, ρ, IdDict()) function update!(o::Momentum, x, Δ) η, ρ = o.eta, o.rho - v = @get!(o.velocity, x, zero(x))::typeof(x) + v = get!(o.velocity, x, zero(x))::typeof(x) @. v = ρ * v - η * Δ @. Δ = -v end @@ -47,14 +47,14 @@ Gradient descent with learning rate `η` and Nesterov momentum `ρ`. mutable struct Nesterov eta::Float64 rho::Float64 - velocity::ObjectIdDict + velocity::IdDict end -Nesterov(η, ρ = 0.9) = Nesterov(η, ρ, ObjectIdDict()) +Nesterov(η, ρ = 0.9) = Nesterov(η, ρ, IdDict()) function update!(o::Nesterov, x, Δ) η, ρ = o.eta, o.rho - v = @get!(o.velocity, x, zero(x))::typeof(x) + v = get!(o.velocity, x, zero(x))::typeof(x) d = @. ρ^2 * v - (1+ρ) * η * Δ @. v = ρ*v - η*Δ @. Δ = -d @@ -70,14 +70,14 @@ choice for recurrent networks. mutable struct RMSProp eta::Float64 rho::Float64 - acc::ObjectIdDict + acc::IdDict end -RMSProp(η = 0.001, ρ = 0.9) = RMSProp(η, ρ, ObjectIdDict()) +RMSProp(η = 0.001, ρ = 0.9) = RMSProp(η, ρ, IdDict()) function update!(o::RMSProp, x, Δ) η, ρ = o.eta, o.rho - acc = @get!(o.acc, x, zero(x))::typeof(x) + acc = get!(o.acc, x, zero(x))::typeof(x) @. acc = ρ * acc + (1 - ρ) * Δ^2 @. Δ *= η / (√acc + ϵ) end @@ -90,14 +90,14 @@ end mutable struct ADAM eta::Float64 beta::Tuple{Float64,Float64} - state::ObjectIdDict + state::IdDict end -ADAM(η = 0.001, β = (0.9, 0.999)) = ADAM(η, β, ObjectIdDict()) +ADAM(η = 0.001, β = (0.9, 0.999)) = ADAM(η, β, IdDict()) function update!(o::ADAM, x, Δ) η, β = o.eta, o.beta - mt, vt, βp = @get!(o.state, x, (zero(x), zero(x), β)) + mt, vt, βp = get!(o.state, x, (zero(x), zero(x), β)) @. mt = β[1] * mt + (1 - β[1]) * Δ @. vt = β[2] * vt + (1 - β[2]) * Δ^2 @. Δ = mt / (1 - βp[1]) / (√(vt / (1 - βp[2])) + ϵ) * η diff --git a/src/tracker/Tracker.jl b/src/tracker/Tracker.jl index 036c0904..e51b464e 100644 --- a/src/tracker/Tracker.jl +++ b/src/tracker/Tracker.jl @@ -1,23 +1,27 @@ module Tracker +using MacroTools +using MacroTools: @q, @forward + import Base: == -export TrackedArray, TrackedVector, TrackedMatrix, param, back! +export TrackedArray, TrackedVector, TrackedMatrix, Params, param, back! tracker(x) = nothing istracked(x) = tracker(x) ≠ nothing isleaf(x) = !istracked(x) || isleaf(tracker(x)) -data(x) = istracked(x) ? data(tracker(x)) : x grad(x) = grad(tracker(x)) grad(::Nothing) = nothing +data(x) = x struct Call{F,As<:Tuple} func::F args::As end -Call(f, args...) 
= Call{typeof(f),typeof(args)}(f, args) +Call(f::F, args::T) where {F,T} = Call{F,T}(f, args) +Call() = Call(nothing, ()) # When deserialising, the object_id changes a::Call == b::Call = a.func == b.func && a.args == b.args @@ -28,33 +32,80 @@ mutable struct Tracked{T} ref::UInt32 f::Call isleaf::Bool - data::T grad::T - Tracked{T}(f::Call, data::T) where T = new(0, f, false, data) - Tracked{T}(f::Call, data::T, grad::T) where T = new(0, f, false, data, grad) - Tracked{T}(f::Call{Void}, data::T, grad::T) where T = new(0, f, true, data, grad) + Tracked{T}(f::Call) where T = new(0, f, false) + Tracked{T}(f::Call, grad::T) where T = new(0, f, false, grad) + Tracked{T}(f::Call{Nothing}, grad::T) where T = new(0, f, true, grad) end -Tracked(f::Call, x) = Tracked{typeof(x)}(f, x) -Tracked(f::Call, x, Δ) = Tracked{typeof(x)}(f, x, Δ) - -track(f::Call, x) = Tracked(f, x) -track(f::Call) = track(f, f()) -track(f, xs...) = track(Call(f, xs...)) - istracked(x::Tracked) = true -isleaf(x::Tracked) = x.f == Call(nothing) -data(x::Tracked) = x.data +isleaf(x::Tracked) = x.f == Call() grad(x::Tracked) = x.grad +track(f::Call, x) = Tracked{typeof(x)}(f) + +function _forward end + +function track(f::F, xs...; kw...) where F + y, back = _forward(f, xs...; kw...) + track(Call(back, tracker.(xs)), y) +end + +macro grad(ex) + @capture(shortdef(ex), (name_(args__) = body_) | + (name_(args__) where {T__} = body_)) || error("Need a function definition") + T == nothing && (T = []) + isexpr(name, :(::)) || (name = :(::typeof($name))) + insert!(args, 1+isexpr(args[1], :parameters) , name) + @q(Tracker._forward($(args...)) where $(T...) = $body) |> esc +end + +function update!(x, Δ) + x.data .+= data(Δ) + tracker(x).grad .= 0 + return x +end + +include("idset.jl") include("back.jl") include("scalar.jl") include("array.jl") include("numeric.jl") +""" + hook(f, x) -> x′ +Hook into gradient backpropagation. `x` is unmodified, but when backpropagating +`f` will be applied to the incoming gradient. For example, `hook(-, x)` will reverse +the sign of the gradient applied to `x`.""" +hook(f, x) = istracked(x) ? track(hook, f, x) : x +@grad hook(f, x) = data(x), Δ -> (nothing, f(Δ)) + +""" + checkpoint(f, args...) +Behaves like `f(args...)`, but avoids storing the intermediate values needed for +calculating gradients. Instead, `f(args...)` will be called again during the +backward pass. This can be used to save memory in larger models. +""" +checkpoint(f, args...) = track(checkpoint, f, args...) + +@grad function checkpoint(f, args...) + data(f(args...)), function (Δ) + y, back = forward(f, args...) + (nothing, back(Δ)...) + end +end + +nobacksies(f, x) = track(nobacksies, f, x) +nobacksies(f, xs::Tuple) = map(x -> nobacksies(f, x), xs) +@grad nobacksies(f, x) = data(x), Δ -> error("Nested AD not defined for $f") + param(x::Number) = TrackedReal(float(x)) param(xs::AbstractArray) = TrackedArray(float.(xs)) +@grad identity(x) = data(x), Δ -> (Δ,) +param(x::TrackedReal) = track(identity, x) +param(x::TrackedArray) = track(identity, x) + import NNlib.cudata import Adapt.adapt diff --git a/src/tracker/array.jl b/src/tracker/array.jl index 882a866c..202a2ca2 100644 --- a/src/tracker/array.jl +++ b/src/tracker/array.jl @@ -87,14 +87,11 @@ Base.adjoint(xs::TrackedArray) = track(adjoint, xs) @grad transpose(xs) = transpose(data(xs)), Δ -> (reshape(transpose(Δ), size(xs)),) @grad adjoint(xs) = data(xs)', Δ -> (reshape(Δ', size(xs)),) - Base.repeat(xs::TrackedArray; kw...) = track(repeat, xs; kw...) 
- @grad function repeat(xs; inner=ntuple(x->1, ndims(xs)), outer=ntuple(x->1, ndims(xs))) repeat(data(xs), inner = inner, outer = outer), function (Δ) Δ′ = zero(xs) S = size(xs) - # Loop through each element of Δ, calculate source dimensions, accumulate into Δ′ for (dest_idx, val) in pairs(IndexCartesian(), data(Δ)) # First, round dest_idx[dim] to nearest gridpoint defined by inner[dim], then @@ -105,7 +102,6 @@ Base.repeat(xs::TrackedArray; kw...) = track(repeat, xs; kw...) (nobacksies(:repeat, Δ′),) end end - for f in [:vcat, :hcat] UArray = :(Union{TrackedArray,Vector,Matrix,Adjoint,Transpose}) @eval begin @@ -361,7 +357,7 @@ end track(Call(back, tracker.(args)), y) end -using Base.Broadcast: BroadcastStyle, ArrayStyle, Broadcasted, broadcasted +using Base.Broadcast: BroadcastStyle, ArrayStyle, Broadcasted, broadcasted, cat_nested struct TrackedStyle <: BroadcastStyle end @@ -385,6 +381,10 @@ end using Requires +Base.Broadcast.cat_nested(t::Base.Broadcast.Broadcasted, rest...) = (cat_nested(t.args...)..., cat_nested(rest...)...) +Base.Broadcast.cat_nested(t::Any, rest...) = (t, cat_nested(rest...)...) +Base.Broadcast.cat_nested() = () + # https://github.com/FluxML/Flux.jl/issues/353 @init Requires.isprecompiling() || @eval Base.Broadcast begin function flatten(bc::Broadcasted{Style}) where {Style} diff --git a/test/optimise.jl b/test/optimise.jl index 502d9ab2..3d864143 100644 --- a/test/optimise.jl +++ b/test/optimise.jl @@ -3,16 +3,18 @@ using Flux.Tracker using Test @testset "Optimise" begin w = randn(10, 10) - @testset for Opt in [SGD, Nesterov, Momentum, ADAM, AdaMax, RMSProp, ps -> ADAGrad(ps, 0.1), ADADelta, AMSGrad, NADAM] - w′ = param(randn(10, 10)) - loss(x) = Flux.mse(w*x, w′*x) - opt = Opt([w′]) - for t=1:10^5 - l = loss(rand(10)) - back!(l) - opt() - end - @test Flux.mse(w, w′) < 0.01 + @testset for Opt in [Descent, Nesterov, RMSProp, ADAM, Momentum] + w′ = param(randn(10, 10)) + delta = param(Tracker.similar(w′)) + loss(x) = Flux.mse(w*x, w′*x) + opt = Opt(0.1) + for t=1:10^5 + l = loss(rand(10)) + back!(l) + update!(opt, w′.data, delta.data) + w′ .-= delta + end + @test Flux.mse(w, w′) < 0.01 end end @@ -23,7 +25,7 @@ end Flux.train!(() -> (sleep(0.1); i += 1; l), Iterators.repeated((), 100), ()->(), - cb = Flux.throttle(() -> (i > 3 && stop()), 1)) + cb = Flux.throttle(() -> (i > 3 && Flux.stop()), 1)) @test 3 < i < 50 end From 4860c1d48badc83b7d82447e3e429f457a1af62d Mon Sep 17 00:00:00 2001 From: Dhairya Gandhi Date: Tue, 11 Sep 2018 18:35:21 +0530 Subject: [PATCH 03/23] fixed white lines --- src/tracker/array.jl | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/tracker/array.jl b/src/tracker/array.jl index 202a2ca2..85dbdc41 100644 --- a/src/tracker/array.jl +++ b/src/tracker/array.jl @@ -87,11 +87,14 @@ Base.adjoint(xs::TrackedArray) = track(adjoint, xs) @grad transpose(xs) = transpose(data(xs)), Δ -> (reshape(transpose(Δ), size(xs)),) @grad adjoint(xs) = data(xs)', Δ -> (reshape(Δ', size(xs)),) + Base.repeat(xs::TrackedArray; kw...) = track(repeat, xs; kw...) + @grad function repeat(xs; inner=ntuple(x->1, ndims(xs)), outer=ntuple(x->1, ndims(xs))) repeat(data(xs), inner = inner, outer = outer), function (Δ) Δ′ = zero(xs) S = size(xs) + # Loop through each element of Δ, calculate source dimensions, accumulate into Δ′ for (dest_idx, val) in pairs(IndexCartesian(), data(Δ)) # First, round dest_idx[dim] to nearest gridpoint defined by inner[dim], then @@ -102,6 +105,7 @@ Base.repeat(xs::TrackedArray; kw...) = track(repeat, xs; kw...) 
(nobacksies(:repeat, Δ′),) end end + for f in [:vcat, :hcat] UArray = :(Union{TrackedArray,Vector,Matrix,Adjoint,Transpose}) @eval begin From 63bc71698b355128b08d4a0740ac62638bfd36ec Mon Sep 17 00:00:00 2001 From: Dhairya Gandhi Date: Fri, 14 Sep 2018 20:32:56 +0530 Subject: [PATCH 04/23] updated tests --- src/optimise/Optimise.jl | 2 +- src/optimise/optimisers.jl | 1 + src/optimise/train.jl | 1 + src/tracker/Tracker.jl | 2 ++ test/optimise.jl | 24 +++++++++++++----------- 5 files changed, 18 insertions(+), 12 deletions(-) diff --git a/src/optimise/Optimise.jl b/src/optimise/Optimise.jl index c8abcf3d..ac53ba25 100644 --- a/src/optimise/Optimise.jl +++ b/src/optimise/Optimise.jl @@ -1,7 +1,7 @@ module Optimise export train!, - Descent, ADAM, Momentum, Nesterov, RMSProp + Descent, ADAM, Momentum, Nesterov, RMSProp, stop, StopException include("optimisers.jl") include("train.jl") diff --git a/src/optimise/optimisers.jl b/src/optimise/optimisers.jl index ce04fe5a..08ce1631 100644 --- a/src/optimise/optimisers.jl +++ b/src/optimise/optimisers.jl @@ -102,6 +102,7 @@ function update!(o::ADAM, x, Δ) @. vt = β[2] * vt + (1 - β[2]) * Δ^2 @. Δ = mt / (1 - βp[1]) / (√(vt / (1 - βp[2])) + ϵ) * η o.state[x] = (mt, vt, βp .* β) + return Δ end # """ diff --git a/src/optimise/train.jl b/src/optimise/train.jl index 85c402e6..f65ccb2a 100644 --- a/src/optimise/train.jl +++ b/src/optimise/train.jl @@ -1,5 +1,6 @@ using Juno using Flux.Tracker: data, grad, back! +import Base.depwarn function update!(opt, xs) for x in xs diff --git a/src/tracker/Tracker.jl b/src/tracker/Tracker.jl index e51b464e..3cd03c1f 100644 --- a/src/tracker/Tracker.jl +++ b/src/tracker/Tracker.jl @@ -74,6 +74,7 @@ include("numeric.jl") """ hook(f, x) -> x′ + Hook into gradient backpropagation. `x` is unmodified, but when backpropagating `f` will be applied to the incoming gradient. For example, `hook(-, x)` will reverse the sign of the gradient applied to `x`.""" @@ -82,6 +83,7 @@ hook(f, x) = istracked(x) ? track(hook, f, x) : x """ checkpoint(f, args...) + Behaves like `f(args...)`, but avoids storing the intermediate values needed for calculating gradients. Instead, `f(args...)` will be called again during the backward pass. This can be used to save memory in larger models. 
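As a rough sketch of the workflow these patches converge on (the optimiser names, `update!`, `back!`, `param` and `Flux.mse` are taken from the diffs above; the toy problem, dimensions and loop count are illustrative only), a training step under the new interface computes the gradient, lets `update!` rewrite it into a step, and applies that step to the parameter by hand:

using Flux, Flux.Tracker

w  = randn(10, 10)            # fixed target weights
w′ = param(randn(10, 10))     # trainable weights
loss(x) = Flux.mse(w * x, w′ * x)

opt = ADAM(0.001)             # mutable struct carrying η, β and per-parameter state

for _ = 1:1000
  l = loss(rand(10))
  back!(l)                                          # fill w′.grad
  Δ = Flux.Optimise.update!(opt, w′.data, w′.grad)  # rescale the raw gradient into a step
  w′.data .-= Δ                                     # apply the step
  w′.grad .= 0                                      # reset the gradient for the next iteration
end

The test diff that follows exercises exactly this pattern for each of the new optimiser types.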
diff --git a/test/optimise.jl b/test/optimise.jl index 3d864143..f61ed822 100644 --- a/test/optimise.jl +++ b/test/optimise.jl @@ -3,18 +3,20 @@ using Flux.Tracker using Test @testset "Optimise" begin w = randn(10, 10) - @testset for Opt in [Descent, Nesterov, RMSProp, ADAM, Momentum] - w′ = param(randn(10, 10)) - delta = param(Tracker.similar(w′)) - loss(x) = Flux.mse(w*x, w′*x) + @testset for Opt in [Descent, ADAM, Nesterov, RMSProp, Momentum] + w′ = param(randn(10, 10)) + loss(x) = Flux.mse(w*x, w′*x) + opt = Opt(0.001) + if opt isa Descent opt = Opt(0.1) - for t=1:10^5 - l = loss(rand(10)) - back!(l) - update!(opt, w′.data, delta.data) - w′ .-= delta - end - @test Flux.mse(w, w′) < 0.01 + end + for t = 1: 10^5 + l = loss(rand(10)) + back!(l) + delta = Optimise.update!(opt, w′.data, w′.grad) + w′.data .-= delta + end + @test Flux.mse(w, w′) < 0.01 end end From 6665189ff11de1bbf03cb2cba7ea2062324adf95 Mon Sep 17 00:00:00 2001 From: Dhairya Gandhi Date: Sun, 16 Sep 2018 17:34:51 +0530 Subject: [PATCH 05/23] added remaining optimizers and tests --- src/Flux.jl | 5 +- src/optimise/Optimise.jl | 4 +- src/optimise/optimisers.jl | 176 +++++++++++++++++++++++++++++++------ src/tracker/array.jl | 6 +- test/optimise.jl | 22 ++++- 5 files changed, 174 insertions(+), 39 deletions(-) diff --git a/src/Flux.jl b/src/Flux.jl index e684be56..0fb4d08a 100644 --- a/src/Flux.jl +++ b/src/Flux.jl @@ -19,8 +19,9 @@ export Tracker, TrackedArray, TrackedVector, TrackedMatrix, param include("optimise/Optimise.jl") using .Optimise using .Optimise: @epochs -export Descent, ADAM, Momentum, Nesterov, - RMSProp, update! +export Descent, ADAM, Momentum, Nesterov, RMSProp, + ADAGrad, AdaMax, ADADelta, AMSGrad, NADAM, + InvDecay, ExpDecay include("utils.jl") include("onehot.jl") diff --git a/src/optimise/Optimise.jl b/src/optimise/Optimise.jl index ac53ba25..76b90311 100644 --- a/src/optimise/Optimise.jl +++ b/src/optimise/Optimise.jl @@ -1,7 +1,9 @@ module Optimise export train!, - Descent, ADAM, Momentum, Nesterov, RMSProp, stop, StopException + Descent, ADAM, Momentum, Nesterov, RMSProp, + ADAGrad, AdaMax, ADADelta, AMSGrad, NADAM, + InvDecay, ExpDecay, stop, StopException, Compose include("optimisers.jl") include("train.jl") diff --git a/src/optimise/optimisers.jl b/src/optimise/optimisers.jl index 08ce1631..18d8336b 100644 --- a/src/optimise/optimisers.jl +++ b/src/optimise/optimisers.jl @@ -20,7 +20,7 @@ function update!(o::Descent, x, Δ) end """ - Momentum(params, η = 0.01; ρ = 0.9, decay = 0) + Momentum(params, η = 0.01; ρ = 0.9) Gradient descent with learning rate `η` and momentum `ρ`. """ @@ -83,7 +83,7 @@ function update!(o::RMSProp, x, Δ) end """ - ADAM(params, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0) + ADAM(η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08) [ADAM](https://arxiv.org/abs/1412.6980v8) optimiser. """ @@ -105,36 +105,154 @@ function update!(o::ADAM, x, Δ) return Δ end -# """ -# AdaMax(params, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0) -# -# [AdaMax](https://arxiv.org/abs/1412.6980v9) optimiser. Variant of ADAM based on -# the ∞-norm. -# """ +""" + AdaMax(params, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08) -# """ -# ADAGrad(params, η = 0.01; ϵ = 1e-8, decay = 0) -# -# [ADAGrad](http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf) optimiser. -# Parameters don't need tuning. -# """ +[AdaMax](https://arxiv.org/abs/1412.6980v9) optimiser. Variant of ADAM based on +the ∞-norm. 
+""" +mutable struct AdaMax + eta::Float64 + beta::Tuple{Float64,Float64} + state::IdDict +end -# """ -# ADADelta(params; ρ = 0.9, ϵ = 1e-8, decay = 0) -# -# [ADADelta](http://arxiv.org/abs/1212.5701) optimiser. Parameters don't need -# tuning. -# """ +AdaMax(η = 0.001, β = (0.9, 0.999)) = AdaMax(η, β, IdDict()) -# """ -# AMSGrad(params; η = 0.001, β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0) -# -# [AMSGrad](https://openreview.net/forum?id=ryQu7f-RZ) optimiser. Parameters don't need -# tuning. -# """ +function update!(o::AdaMax, x, Δ) + η, β = o.eta, o.beta + mt, ut, βp = get!(o.state, x, (zero(x), zero(x), β)) + @. mt = β[1] * mt + (1 - β[1]) * Δ + @. ut = max(β[2] * ut, abs(Δ)) + @. Δ = (η/(1 - βp[1])) * mt/(ut + ϵ) + o.state[x] = (mt, ut, βp .* β) + return Δ +end -# struct Optimiser -# os::Vector{Any} -# end +""" + ADAGrad(η = 0.1; ϵ = 1e-8) + +[ADAGrad](http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf) optimiser. +Parameters don't need tuning. +""" +mutable struct ADAGrad + eta::Float64 + acc::IdDict +end + +ADAGrad(η = 0.1) = ADAGrad(η, IdDict()) + +function update!(o::ADAGrad, x, Δ) + η = o.eta + acc = get!(o.acc, x, fill(ϵ, size(x)))::typeof(x) + @. acc += Δ^2 + @. Δ *= η / √(acc + ϵ) +end + +""" + ADADelta(params; ρ = 0.9, ϵ = 1e-8) + +[ADADelta](http://arxiv.org/abs/1212.5701) optimiser. Parameters don't need +tuning. +""" +mutable struct ADADelta + rho::Float64 + state::IdDict +end + +ADADelta(ρ = 0.9) = ADADelta(ρ, IdDict()) + +function update!(o::ADADelta, x, Δ) + ρ = o.rho + acc, Δacc = get!(o.state, x, (zero(x), zero(x))) + @. acc = ρ * acc + (1 - ρ) * Δ^2 + @. Δ *= √(Δacc + ϵ) / √(acc + ϵ) + @. Δacc = ρ * Δacc + (1 - ρ) * Δ^2 + return Δ +end + +""" + AMSGrad(η = 0.001, β = (0.9, 0.999)) + +[AMSGrad](https://openreview.net/forum?id=ryQu7f-RZ) optimiser. Parameters don't need +tuning. +""" +mutable struct AMSGrad + eta::Float64 + beta::Tuple{Float64, Float64} + state::IdDict +end + +AMSGrad(η = 0.001, β = (0.9, 0.999)) = AMSGrad(η, β, IdDict()) + +function update!(o::AMSGrad, x, Δ) + η, β = o.eta, o.beta + mt, vt, v̂t = get!(o.state, x, (fill(ϵ, size(x)), fill(ϵ, size(x)), fill(ϵ, size(x)))) + @. mt = β[1] * mt + (1 - β[1]) * Δ + @. vt = β[2] * vt + (1 - β[2]) * Δ ^ 2 + @. v̂t = max.(v̂t, vt) + @. Δ = η * mt / √v̂t +end + +""" + NADAM(η = 0.001, β = (0.9, 0.999)) + +[NADAM](http://cs229.stanford.edu/proj2015/054_report.pdf) optimiser. Parameters don't need +tuning. +""" +mutable struct NADAM + eta::Float64 + beta::Tuple{Float64, Float64} + state::IdDict +end + +NADAM(η = 0.001, β = (0.9, 0.999)) = NADAM(η, β, IdDict()) + +function update!(o::NADAM, x, Δ) + η, β = o.eta, o.beta + β1p, β2p = o.beta + mt, vt = get!(o.state, x, (zero(x), zero(x))) + @. mt = β[1] * mt + (1 - β[1]) * Δ + @. vt = β[2] * vt + (1 - β[2]) * Δ^2 + @. Δ = (β[1] * mt / (1 - β[1] * β1p) + (1 - β[1]) * Δ / (1 - β1p)) / √(vt * β[2] / (1 - β2p) + ϵ) * η + o.state[x] = (mt, vt, (β1p * β[1], β2p * β[2])) + return Δ +end + +mutable struct Compose + os::Vector{Any} +end + +function update!(o::Compose, x, Δ) + for opt in o.os + Δ = update!(opt, x, Δ) + end + return Δ +end # TODO: decay + +mutable struct InvDecay + gamma::Float64 + n::Int64 +end + +InvDecay(γ = 0.001, n = 0) = InvDecay(γ, n) + +function update!(o::InvDecay, x, Δ) + γ, n = o.gamma, o.n + Δ .*= 1 / (1 + γ * n) + o.n += 1 + return Δ +end + +mutable struct ExpDecay + gamma::Float64 +end + +ExpDecay(γ = 0.001) = ExpDecay(γ) + +function update!(o::ExpDecay, x, Δ) + γ = o.gamma + @. 
Δ += γ * x +end diff --git a/src/tracker/array.jl b/src/tracker/array.jl index 85dbdc41..882a866c 100644 --- a/src/tracker/array.jl +++ b/src/tracker/array.jl @@ -361,7 +361,7 @@ end track(Call(back, tracker.(args)), y) end -using Base.Broadcast: BroadcastStyle, ArrayStyle, Broadcasted, broadcasted, cat_nested +using Base.Broadcast: BroadcastStyle, ArrayStyle, Broadcasted, broadcasted struct TrackedStyle <: BroadcastStyle end @@ -385,10 +385,6 @@ end using Requires -Base.Broadcast.cat_nested(t::Base.Broadcast.Broadcasted, rest...) = (cat_nested(t.args...)..., cat_nested(rest...)...) -Base.Broadcast.cat_nested(t::Any, rest...) = (t, cat_nested(rest...)...) -Base.Broadcast.cat_nested() = () - # https://github.com/FluxML/Flux.jl/issues/353 @init Requires.isprecompiling() || @eval Base.Broadcast begin function flatten(bc::Broadcasted{Style}) where {Style} diff --git a/test/optimise.jl b/test/optimise.jl index f61ed822..a85e8976 100644 --- a/test/optimise.jl +++ b/test/optimise.jl @@ -3,13 +3,16 @@ using Flux.Tracker using Test @testset "Optimise" begin w = randn(10, 10) - @testset for Opt in [Descent, ADAM, Nesterov, RMSProp, Momentum] + @testset for Opt in [ADAGrad, AdaMax, ADADelta, AMSGrad, NADAM, Descent, ADAM, Nesterov, RMSProp, Momentum] w′ = param(randn(10, 10)) loss(x) = Flux.mse(w*x, w′*x) opt = Opt(0.001) - if opt isa Descent + if opt isa Descent || opt isa ADAGrad opt = Opt(0.1) end + if opt isa ADADelta + opt = Opt(0.9) + end for t = 1: 10^5 l = loss(rand(10)) back!(l) @@ -20,6 +23,21 @@ using Test end end +@testset "Compose" begin + w = randn(10, 10) + @testset for Opt in [InvDecay, ExpDecay] + w′ = param(randn(10, 10)) + loss(x) = Flux.mse(w*x, w′*x) + opt = Compose(vec([Opt(), ADAM(0.001)])) + for t = 1:10^5 + l = loss(rand(10)) + back!(l) + delta = Optimise.update!(opt, w′.data, w′.grad) + w′.data .-= delta + end + @test Flux.mse(w, w′) < 0.01 +end + @testset "Training Loop" begin i = 0 l = param(1) From 87c7e65a2dc5a1b1d2270a6db06c135cc0eafa6a Mon Sep 17 00:00:00 2001 From: Dhairya Gandhi Date: Sun, 16 Sep 2018 17:45:29 +0530 Subject: [PATCH 06/23] fixed Compose test --- test/optimise.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/test/optimise.jl b/test/optimise.jl index a85e8976..ed56e2a2 100644 --- a/test/optimise.jl +++ b/test/optimise.jl @@ -36,6 +36,7 @@ end w′.data .-= delta end @test Flux.mse(w, w′) < 0.01 + end end @testset "Training Loop" begin From b661db37974751c26986b1ef8f1992e2e452191c Mon Sep 17 00:00:00 2001 From: Dhairya Gandhi Date: Mon, 1 Oct 2018 05:30:53 +0530 Subject: [PATCH 07/23] added deprecations and compose --- src/optimise/Optimise.jl | 3 +- src/optimise/deprecations.jl | 128 +++++++++++++++++++++++++++++++++++ src/optimise/optimisers.jl | 54 ++++++++++++++- 3 files changed, 182 insertions(+), 3 deletions(-) create mode 100644 src/optimise/deprecations.jl diff --git a/src/optimise/Optimise.jl b/src/optimise/Optimise.jl index 76b90311..b6f18532 100644 --- a/src/optimise/Optimise.jl +++ b/src/optimise/Optimise.jl @@ -2,10 +2,11 @@ module Optimise export train!, Descent, ADAM, Momentum, Nesterov, RMSProp, - ADAGrad, AdaMax, ADADelta, AMSGrad, NADAM, + ADAGrad, AdaMax, ADADelta, AMSGrad, NADAM, ADAMW, InvDecay, ExpDecay, stop, StopException, Compose include("optimisers.jl") include("train.jl") +include("deprecations.jl") end \ No newline at end of file diff --git a/src/optimise/deprecations.jl b/src/optimise/deprecations.jl new file mode 100644 index 00000000..6a297619 --- /dev/null +++ b/src/optimise/deprecations.jl @@ -0,0 +1,128 @@ +using Base: 
depwarn + +function check_decay(opt, decay) + if decay == 0. + opt = opt + else + if opt isa ADAMW + opt = Compose(opt, DescentWeightDecay(1, decay)) + else + opt = Compose(opt, InvDecay(decay)) + end + end + opt +end + +# legacy update rule +function updaterule(opt, ps) + () -> begin + for p in ps + delta = update!(opt, p.data, p.grad) + p.data .-= delta + end + end +end + +function Descent(params::AbstractArray, η = 0.1; decay = 0.) + depwarn("Descent(ps::Param) is deprecated; use Descent(η::Float64) instead", :Descent) + + ps = params + opt = Descent(η) + opt = check_decay(opt, decay) + updaterule(opt, ps) +end + +function Momentum(params::AbstractArray, η = 0.01; ρ = 0.9, decay = 0.) + depwarn("Momentum(ps::Param) is deprecated; use Momentum(η::Float64) instead", :Momentum) + + ps = params + opt = Momentum(η, ρ) + opt = check_decay(opt, decay) + updaterule(opt, ps) +end + +function Nesterov(params::AbstractArray, η = 0.001; ρ = 0.9, decay = 0.) + depwarn("Nesterov(ps::Param) is deprecated; use Nesterov(η::Float64) instead", :Nesterov) + + ps = params + opt = Nesterov(η, ρ) + opt = check_decay(opt, decay) + updaterule(opt, ps) +end + +function RMSProp(params::AbstractArray, η = 0.001; ρ = 0.9, decay = 0.) + depwarn("RMSProp(ps::Param) is deprecated; use RMSProp(η::Float64) instead", :RMSProp) + + ps = params + opt = RMSProp(η, ρ) + opt = check_decay(opt, decay) + updaterule(opt, ps) +end + +function ADAM(params::AbstractArray, η = 0.001; β1 = 0.9, β2 = 0.999, decay = 0.) + depwarn("ADAM(ps::Param) is deprecated; use ADAM(η::Float64) instead", :ADAM) + + ps = params + β = (β1, β2) + opt = ADAM(η, β) + opt = check_decay(opt, decay) + updaterule(opt, ps) +end + +function ADAGrad(params::AbstractArray, η::Float64 = 0.1; decay = 0.) + depwarn("ADAGrad(ps::Param) is deprecated; use ADAGrad(η::Float64) instead", :ADAGrad) + + ps = params + opt = ADAGrad(η) + opt = check_decay(opt, decay) + updaterule(opt, ps) +end + +function ADADelta(params::AbstractArray, ρ::Float64 = 0.9; decay = 0.) + depwarn("ADADelta(ps::Param) is deprecated; use ADADelta(η::Float64) instead", :ADADelta) + + ps = params + opt = ADADelta(ρ) + opt = check_decay(opt, decay) + updaterule(opt, ps) +end + +function AdaMax(params::AbstractArray, η = 0.001; β1 = 0.9, β2 = 0.999, decay = 0.) + depwarn("AdaMax(ps::Param) is deprecated; use AdaMax(η::Float64) instead", :AdaMax) + + ps = params + β = (β1, β2) + opt = AdaMax(η, β) + opt = check_decay(opt, decay) + updaterule(opt, ps) +end + +function AMSGrad(params::AbstractArray, η = 0.001; β1 = 0.9, β2 = 0.999, decay = 0.) + depwarn("AMSGrad(ps::Param) is deprecated; use AMSGrad(η::Float64) instead", :AMSGrad) + + ps = params + β = (β1, β2) + opt = AMSGrad(η, β) + opt = check_decay(opt, decay) + updaterule(opt, ps) +end + +function NADAM(params::AbstractArray, η = 0.001; β1 = 0.9, β2 = 0.999, decay = 0.) + depwarn("NADAM(ps::Param) is deprecated; use NADAM(η::Float64) instead", :NADAM) + + ps = params + β = (β1, β2) + opt = NADAM(η, β) + opt = check_decay(opt, decay) + updaterule(opt, ps) +end + +function ADAMW(params::AbstractArray, η = 0.001; β1 = 0.9, β2 = 0.999, decay = 0.) 
+ depwarn("ADAMW(ps::Param) is deprecated; use ADAMW(η::Float64) instead", :ADAMW) + + ps = params + β = (β1, β2) + opt = ADAMW(η, β) + opt = check_decay(opt, decay) + updaterule(opt, ps) +end \ No newline at end of file diff --git a/src/optimise/optimisers.jl b/src/optimise/optimisers.jl index 18d8336b..4005db4f 100644 --- a/src/optimise/optimisers.jl +++ b/src/optimise/optimisers.jl @@ -1,5 +1,6 @@ using Flux using Base: @get! +using MacroTools: @forward const ϵ = 1e-8 @@ -15,6 +16,7 @@ mutable struct Descent eta::Float64 end +Descent(η = 0.1) = Descent(η) function update!(o::Descent, x, Δ) Δ .*= o.eta end @@ -30,7 +32,7 @@ mutable struct Momentum velocity::IdDict end -Momentum(η, ρ = 0.9) = Momentum(η, ρ, IdDict()) +Momentum(η = 0.01, ρ = 0.9) = Momentum(η, ρ, IdDict()) function update!(o::Momentum, x, Δ) η, ρ = o.eta, o.rho @@ -50,7 +52,7 @@ mutable struct Nesterov velocity::IdDict end -Nesterov(η, ρ = 0.9) = Nesterov(η, ρ, IdDict()) +Nesterov(η = 0.001, ρ = 0.9) = Nesterov(η, ρ, IdDict()) function update!(o::Nesterov, x, Δ) η, ρ = o.eta, o.rho @@ -219,10 +221,46 @@ function update!(o::NADAM, x, Δ) return Δ end +""" + ADAMW((η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0) + +[ADAMW](https://arxiv.org/abs/1711.05101) fixing weight decay regularization in Adam. +""" +ADAMW(η = 0.001, β = (0.9, 0.999), η_decay = 1, γ_decay = 0) = Compose(ADAM(η, β, IdDict()), DescentWeightDecay(η_decay, γ_decay)) + +# Compose optimizers + +""" + `Compose(Compose(...), ...)` + +Compose optimizers to support inbuilt or custom gradient updates while fitting the loss. + +Example:\n\n +`Compose(ADAM(), Compose(RMSProp(0.001), ExpDecay(0.02)))` +""" mutable struct Compose os::Vector{Any} end +Compose(o...) = Compose(flattenCompose(o...)) + +@forward Compose.os Base.getindex, Base.first, Base.last, Base.lastindex, Base.push!, Base.setindex! +@forward Compose.os Base.iterate + +Base.getindex(c::Compose, i::AbstractArray) = Compose(c.os[i]...) + +function flattenCompose(o...) + res = [] + for opt in o + if opt isa Compose + push!(res, flattenCompose(opt.os...)...) + else + push!(res, opt) + end + end + return res +end + function update!(o::Compose, x, Δ) for opt in o.os Δ = update!(opt, x, Δ) @@ -256,3 +294,15 @@ function update!(o::ExpDecay, x, Δ) γ = o.gamma @. Δ += γ * x end + +mutable struct DescentWeightDecay + eta::Real + gamma::Real +end + +DescentWeightDecay(η = 1, γ = 0) = DescentWeightDecay(η, γ) +function update!(o::DescentWeightDecay, x, Δ) + η, γ = o.eta, o.gamma + @. 
x = x - η * (Δ + γ * x) + Δ +end From 4abe5185990d06f67bc298d7c69d4d060bcd0644 Mon Sep 17 00:00:00 2001 From: Mike Innes Date: Fri, 5 Oct 2018 12:37:47 +0100 Subject: [PATCH 08/23] newline fixes --- src/optimise/Optimise.jl | 2 +- src/optimise/deprecations.jl | 2 +- src/tracker/Tracker.jl | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/optimise/Optimise.jl b/src/optimise/Optimise.jl index b6f18532..873a3ece 100644 --- a/src/optimise/Optimise.jl +++ b/src/optimise/Optimise.jl @@ -9,4 +9,4 @@ include("optimisers.jl") include("train.jl") include("deprecations.jl") -end \ No newline at end of file +end diff --git a/src/optimise/deprecations.jl b/src/optimise/deprecations.jl index 6a297619..b8aac8c0 100644 --- a/src/optimise/deprecations.jl +++ b/src/optimise/deprecations.jl @@ -125,4 +125,4 @@ function ADAMW(params::AbstractArray, η = 0.001; β1 = 0.9, β2 = 0.999, decay opt = ADAMW(η, β) opt = check_decay(opt, decay) updaterule(opt, ps) -end \ No newline at end of file +end diff --git a/src/tracker/Tracker.jl b/src/tracker/Tracker.jl index 3cd03c1f..190837ab 100644 --- a/src/tracker/Tracker.jl +++ b/src/tracker/Tracker.jl @@ -114,4 +114,4 @@ import Adapt.adapt cudata(x::TrackedArray) = data(x) adapt(T, xs::TrackedArray) = param(adapt(T, data(xs))) -end \ No newline at end of file +end From 9bc9771a8dc807da6dc278a6634d7c732b0b1193 Mon Sep 17 00:00:00 2001 From: Mike Innes Date: Fri, 5 Oct 2018 12:43:03 +0100 Subject: [PATCH 09/23] tweaks --- src/optimise/Optimise.jl | 2 +- src/optimise/optimisers.jl | 8 ++++---- src/optimise/train.jl | 2 +- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/optimise/Optimise.jl b/src/optimise/Optimise.jl index 873a3ece..4c5c8290 100644 --- a/src/optimise/Optimise.jl +++ b/src/optimise/Optimise.jl @@ -3,7 +3,7 @@ module Optimise export train!, Descent, ADAM, Momentum, Nesterov, RMSProp, ADAGrad, AdaMax, ADADelta, AMSGrad, NADAM, ADAMW, - InvDecay, ExpDecay, stop, StopException, Compose + InvDecay, ExpDecay, stop, Compose include("optimisers.jl") include("train.jl") diff --git a/src/optimise/optimisers.jl b/src/optimise/optimisers.jl index 4005db4f..ae30445a 100644 --- a/src/optimise/optimisers.jl +++ b/src/optimise/optimisers.jl @@ -16,7 +16,7 @@ mutable struct Descent eta::Float64 end -Descent(η = 0.1) = Descent(η) +Descent() = Descent(0.1) function update!(o::Descent, x, Δ) Δ .*= o.eta end @@ -275,7 +275,7 @@ mutable struct InvDecay n::Int64 end -InvDecay(γ = 0.001, n = 0) = InvDecay(γ, n) +InvDecay(γ = 0.001) = InvDecay(γ, 0) function update!(o::InvDecay, x, Δ) γ, n = o.gamma, o.n @@ -288,7 +288,7 @@ mutable struct ExpDecay gamma::Float64 end -ExpDecay(γ = 0.001) = ExpDecay(γ) +ExpDecay() = ExpDecay(0.001) function update!(o::ExpDecay, x, Δ) γ = o.gamma @@ -300,7 +300,7 @@ mutable struct DescentWeightDecay gamma::Real end -DescentWeightDecay(η = 1, γ = 0) = DescentWeightDecay(η, γ) +DescentWeightDecay(η = 1) = DescentWeightDecay(η, 0) function update!(o::DescentWeightDecay, x, Δ) η, γ = o.eta, o.gamma @. 
x = x - η * (Δ + γ * x) diff --git a/src/optimise/train.jl b/src/optimise/train.jl index f65ccb2a..a8a3b4a0 100644 --- a/src/optimise/train.jl +++ b/src/optimise/train.jl @@ -5,7 +5,7 @@ import Base.depwarn function update!(opt, xs) for x in xs x, Δ = data(x), grad(x) - update!(opt, x, Δ) + Δ = update!(opt, x, Δ) x .-= Δ Δ .= 0 end From 0f2019eba5d2f2c61e90c5594f13954d9cff0f3f Mon Sep 17 00:00:00 2001 From: Mike Innes Date: Fri, 5 Oct 2018 12:57:03 +0100 Subject: [PATCH 10/23] compose tweaks --- src/optimise/optimisers.jl | 26 ++++++-------------------- 1 file changed, 6 insertions(+), 20 deletions(-) diff --git a/src/optimise/optimisers.jl b/src/optimise/optimisers.jl index ae30445a..c3db9959 100644 --- a/src/optimise/optimisers.jl +++ b/src/optimise/optimisers.jl @@ -222,7 +222,7 @@ function update!(o::NADAM, x, Δ) end """ - ADAMW((η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0) + ADAMW((η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0) [ADAMW](https://arxiv.org/abs/1711.05101) fixing weight decay regularization in Adam. """ @@ -231,36 +231,22 @@ ADAMW(η = 0.001, β = (0.9, 0.999), η_decay = 1, γ_decay = 0) = Compose(ADAM( # Compose optimizers """ - `Compose(Compose(...), ...)` + Compose(a, b, c...) -Compose optimizers to support inbuilt or custom gradient updates while fitting the loss. - -Example:\n\n -`Compose(ADAM(), Compose(RMSProp(0.001), ExpDecay(0.02)))` +Combine several optimisers into one; each optimiser produces a modified gradient +that will be fed into the next, and this is finally applied to the parameter as +usual. """ mutable struct Compose os::Vector{Any} + Compose(o...) = Compose(Any[o...]) end -Compose(o...) = Compose(flattenCompose(o...)) - @forward Compose.os Base.getindex, Base.first, Base.last, Base.lastindex, Base.push!, Base.setindex! @forward Compose.os Base.iterate Base.getindex(c::Compose, i::AbstractArray) = Compose(c.os[i]...) -function flattenCompose(o...) - res = [] - for opt in o - if opt isa Compose - push!(res, flattenCompose(opt.os...)...) - else - push!(res, opt) - end - end - return res -end - function update!(o::Compose, x, Δ) for opt in o.os Δ = update!(opt, x, Δ) From bfe85e65f11fbd9ddc581ebf488cb7d472484171 Mon Sep 17 00:00:00 2001 From: Mike Innes Date: Fri, 5 Oct 2018 13:52:26 +0100 Subject: [PATCH 11/23] compose tweaks --- src/optimise/optimisers.jl | 26 ++++++-------------------- 1 file changed, 6 insertions(+), 20 deletions(-) diff --git a/src/optimise/optimisers.jl b/src/optimise/optimisers.jl index ae30445a..2d62cf26 100644 --- a/src/optimise/optimisers.jl +++ b/src/optimise/optimisers.jl @@ -222,7 +222,7 @@ function update!(o::NADAM, x, Δ) end """ - ADAMW((η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0) + ADAMW((η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0) [ADAMW](https://arxiv.org/abs/1711.05101) fixing weight decay regularization in Adam. """ @@ -231,36 +231,22 @@ ADAMW(η = 0.001, β = (0.9, 0.999), η_decay = 1, γ_decay = 0) = Compose(ADAM( # Compose optimizers """ - `Compose(Compose(...), ...)` + Compose(a, b, c...) -Compose optimizers to support inbuilt or custom gradient updates while fitting the loss. - -Example:\n\n -`Compose(ADAM(), Compose(RMSProp(0.001), ExpDecay(0.02)))` +Combine several optimisers into one; each optimiser produces a modified gradient +that will be fed into the next, and this is finally applied to the parameter as +usual. """ mutable struct Compose os::Vector{Any} + Compose(o...) = new(Any[o...]) end -Compose(o...) 
= Compose(flattenCompose(o...)) - @forward Compose.os Base.getindex, Base.first, Base.last, Base.lastindex, Base.push!, Base.setindex! @forward Compose.os Base.iterate Base.getindex(c::Compose, i::AbstractArray) = Compose(c.os[i]...) -function flattenCompose(o...) - res = [] - for opt in o - if opt isa Compose - push!(res, flattenCompose(opt.os...)...) - else - push!(res, opt) - end - end - return res -end - function update!(o::Compose, x, Δ) for opt in o.os Δ = update!(opt, x, Δ) From fe8c147f725969c63c147e8a078e44202c403b5a Mon Sep 17 00:00:00 2001 From: Dhairya Gandhi Date: Thu, 11 Oct 2018 10:07:16 +0530 Subject: [PATCH 12/23] fixed weight decay definition --- src/Flux.jl | 2 +- src/optimise/Optimise.jl | 2 +- src/optimise/deprecations.jl | 10 ++++++++-- src/optimise/optimisers.jl | 34 ++++++++++++++++++---------------- src/optimise/train.jl | 9 +++++---- test/optimise.jl | 4 ++-- 6 files changed, 35 insertions(+), 26 deletions(-) diff --git a/src/Flux.jl b/src/Flux.jl index 0fb4d08a..b09cda17 100644 --- a/src/Flux.jl +++ b/src/Flux.jl @@ -21,7 +21,7 @@ using .Optimise using .Optimise: @epochs export Descent, ADAM, Momentum, Nesterov, RMSProp, ADAGrad, AdaMax, ADADelta, AMSGrad, NADAM, - InvDecay, ExpDecay + ADAMW, InvDecay, ExpDecay, WeightDecay, DescentWeightDecay include("utils.jl") include("onehot.jl") diff --git a/src/optimise/Optimise.jl b/src/optimise/Optimise.jl index 4c5c8290..cf12a3c3 100644 --- a/src/optimise/Optimise.jl +++ b/src/optimise/Optimise.jl @@ -3,7 +3,7 @@ module Optimise export train!, Descent, ADAM, Momentum, Nesterov, RMSProp, ADAGrad, AdaMax, ADADelta, AMSGrad, NADAM, ADAMW, - InvDecay, ExpDecay, stop, Compose + InvDecay, ExpDecay, WeightDecay, stop, Optimiser include("optimisers.jl") include("train.jl") diff --git a/src/optimise/deprecations.jl b/src/optimise/deprecations.jl index b8aac8c0..979eaebc 100644 --- a/src/optimise/deprecations.jl +++ b/src/optimise/deprecations.jl @@ -5,9 +5,9 @@ function check_decay(opt, decay) opt = opt else if opt isa ADAMW - opt = Compose(opt, DescentWeightDecay(1, decay)) + opt = Optimiser(opt, DescentWeightDecay(1, decay)) else - opt = Compose(opt, InvDecay(decay)) + opt = Optimiser(opt, InvDecay(decay)) end end opt @@ -126,3 +126,9 @@ function ADAMW(params::AbstractArray, η = 0.001; β1 = 0.9, β2 = 0.999, decay opt = check_decay(opt, decay) updaterule(opt, ps) end + +# Train function +function train!(loss::Function, data, opt; cb = () -> ()) + depwarn("train!(loss, data, opt; cb) is deprecated; use train!(model, data, loss, opt; cb) instead", :train) + train!(opt.ps, loss, data, opt.opt; cb = cb) +end \ No newline at end of file diff --git a/src/optimise/optimisers.jl b/src/optimise/optimisers.jl index c3db9959..119732da 100644 --- a/src/optimise/optimisers.jl +++ b/src/optimise/optimisers.jl @@ -85,7 +85,7 @@ function update!(o::RMSProp, x, Δ) end """ - ADAM(η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08) + ADAM(η = 0.001, β = (0.9, 0.999)) [ADAM](https://arxiv.org/abs/1412.6980v8) optimiser. """ @@ -226,28 +226,29 @@ end [ADAMW](https://arxiv.org/abs/1711.05101) fixing weight decay regularization in Adam. """ -ADAMW(η = 0.001, β = (0.9, 0.999), η_decay = 1, γ_decay = 0) = Compose(ADAM(η, β, IdDict()), DescentWeightDecay(η_decay, γ_decay)) +ADAMW(η = 0.001, β = (0.9, 0.999), η_decay = 1, γ_decay = 0) = Optimiser(ADAM(η, β, IdDict()), DescentWeightDecay(η_decay, γ_decay)) # Compose optimizers """ - Compose(a, b, c...) + Optimiser(a, b, c...) 
Combine several optimisers into one; each optimiser produces a modified gradient that will be fed into the next, and this is finally applied to the parameter as usual. """ -mutable struct Compose +mutable struct Optimiser os::Vector{Any} - Compose(o...) = Compose(Any[o...]) end -@forward Compose.os Base.getindex, Base.first, Base.last, Base.lastindex, Base.push!, Base.setindex! -@forward Compose.os Base.iterate +Optimiser(o...) = Optimiser(Any[o...]) -Base.getindex(c::Compose, i::AbstractArray) = Compose(c.os[i]...) +@forward Optimiser.os Base.getindex, Base.first, Base.last, Base.lastindex, Base.push!, Base.setindex! +@forward Optimiser.os Base.iterate -function update!(o::Compose, x, Δ) +Base.getindex(c::Optimiser, i::AbstractArray) = Optimiser(c.os[i]...) + +function update!(o::Optimiser, x, Δ) for opt in o.os Δ = update!(opt, x, Δ) end @@ -281,14 +282,15 @@ function update!(o::ExpDecay, x, Δ) @. Δ += γ * x end -mutable struct DescentWeightDecay +mutable struct WeightDecay eta::Real - gamma::Real + wd::Real end -DescentWeightDecay(η = 1) = DescentWeightDecay(η, 0) -function update!(o::DescentWeightDecay, x, Δ) - η, γ = o.eta, o.gamma - @. x = x - η * (Δ + γ * x) - Δ +WeightDecay(η = 1) = WeightDecay(η, 0) +function update!(o::WeightDecay, x, Δ) + η, wd = o.eta, o.wd + @. Δ += wd * x end + +DescentWeightDecay(η = 0.1, γ = 0) = Optimiser(WeightDecay(), Descent(η)) \ No newline at end of file diff --git a/src/optimise/train.jl b/src/optimise/train.jl index a8a3b4a0..2fbe6b85 100644 --- a/src/optimise/train.jl +++ b/src/optimise/train.jl @@ -45,7 +45,7 @@ function stop() end """ - train!(loss, data, opt) + train!(model, loss, data, opt) For each datapoint `d` in `data` computes the gradient of `loss(d...)` through backpropagation and calls the optimizer `opt`. @@ -54,7 +54,7 @@ Takes a callback as keyword argument `cb`. For example, this will print "trainin every 10 seconds: ```julia -Flux.train!(loss, data, opt, +Flux.train!(model, loss, data, opt, cb = throttle(() -> println("training"), 10)) ``` @@ -62,14 +62,14 @@ The callback can return `:stop` to interrupt the training loop. Multiple optimisers and callbacks can be passed to `opt` and `cb` as arrays. """ -function train!(loss, data, opt; cb = () -> ()) +function train!(ps::Array, loss, data, opt; cb = () -> ()) cb = runall(cb) opt = runall(opt) @progress for d in data try l = loss(d...) 
@interrupts back!(l) - opt() + foreach(x -> x.data .-= update!(opt, x.data, x.grad), ps) if cb() == :stop depwarn("Use of `:stop` is deprecated; use `Flux.stop()` instead", :stop) break @@ -83,6 +83,7 @@ function train!(loss, data, opt; cb = () -> ()) end end end +train!(model, loss, data, opt; cb = () -> ()) = train!(params(model), loss, data, opt; cb = cb) """ @epochs N body diff --git a/test/optimise.jl b/test/optimise.jl index ed56e2a2..b2e3f13b 100644 --- a/test/optimise.jl +++ b/test/optimise.jl @@ -23,12 +23,12 @@ using Test end end -@testset "Compose" begin +@testset "Optimiser" begin w = randn(10, 10) @testset for Opt in [InvDecay, ExpDecay] w′ = param(randn(10, 10)) loss(x) = Flux.mse(w*x, w′*x) - opt = Compose(vec([Opt(), ADAM(0.001)])) + opt = Optimiser(Opt(), ADAM(0.001)) for t = 1:10^5 l = loss(rand(10)) back!(l) From 1f0f2a5ac26e466bf0dc05b1340172688b0b5c00 Mon Sep 17 00:00:00 2001 From: Dhairya Gandhi Date: Thu, 11 Oct 2018 10:21:29 +0530 Subject: [PATCH 13/23] fixed DescentWeightDecay parameters --- src/optimise/optimisers.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/optimise/optimisers.jl b/src/optimise/optimisers.jl index 58e4e7df..02dbb547 100644 --- a/src/optimise/optimisers.jl +++ b/src/optimise/optimisers.jl @@ -226,7 +226,7 @@ end [ADAMW](https://arxiv.org/abs/1711.05101) fixing weight decay regularization in Adam. """ -ADAMW(η = 0.001, β = (0.9, 0.999), η_decay = 1, γ_decay = 0) = Optimiser(ADAM(η, β, IdDict()), DescentWeightDecay(η_decay, γ_decay)) +ADAMW(η = 0.001, β = (0.9, 0.999), η_decay = 1, wd = 0) = Optimiser(ADAM(η, β, IdDict()), DescentWeightDecay(η_decay, wd)) # Compose optimizers @@ -292,4 +292,4 @@ function update!(o::WeightDecay, x, Δ) @. Δ += wd * x end -DescentWeightDecay(η = 0.1, γ = 0) = Optimiser(WeightDecay(), Descent(η)) \ No newline at end of file +DescentWeightDecay(η = 1, wd = 0) = Optimiser(WeightDecay(1, wd), Descent(η)) \ No newline at end of file From edbcd3c9ea530d0d385408104353b56a4e92fd2f Mon Sep 17 00:00:00 2001 From: Dhairya Gandhi Date: Thu, 11 Oct 2018 18:52:16 +0530 Subject: [PATCH 14/23] fix train! 
test --- test/optimise.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/optimise.jl b/test/optimise.jl index b2e3f13b..0cbcf413 100644 --- a/test/optimise.jl +++ b/test/optimise.jl @@ -45,7 +45,7 @@ end Flux.train!(() -> (sleep(0.1); i += 1; l), Iterators.repeated((), 100), - ()->(), + ADAM([l]), cb = Flux.throttle(() -> (i > 3 && Flux.stop()), 1)) @test 3 < i < 50 From 815e8c206d1b3f75f5aa86cda9461ec95225d6d9 Mon Sep 17 00:00:00 2001 From: Dhairya Gandhi Date: Sat, 27 Oct 2018 19:26:42 +0530 Subject: [PATCH 15/23] decay fixes --- src/optimise/deprecations.jl | 12 ++++++++---- src/optimise/optimisers.jl | 38 ++++++++++++++++++++++++------------ src/optimise/train.jl | 16 +++++++++------ test/optimise.jl | 10 ++++++---- 4 files changed, 50 insertions(+), 26 deletions(-) diff --git a/src/optimise/deprecations.jl b/src/optimise/deprecations.jl index 979eaebc..228c3c29 100644 --- a/src/optimise/deprecations.jl +++ b/src/optimise/deprecations.jl @@ -5,7 +5,7 @@ function check_decay(opt, decay) opt = opt else if opt isa ADAMW - opt = Optimiser(opt, DescentWeightDecay(1, decay)) + opt = Optimiser(opt, WeightDecay(decay)) else opt = Optimiser(opt, InvDecay(decay)) end @@ -129,6 +129,10 @@ end # Train function function train!(loss::Function, data, opt; cb = () -> ()) - depwarn("train!(loss, data, opt; cb) is deprecated; use train!(model, data, loss, opt; cb) instead", :train) - train!(opt.ps, loss, data, opt.opt; cb = cb) -end \ No newline at end of file + depwarn("train!(loss, data, opt; cb) is deprecated; use train!(loss, params, data, opt; cb) instead", :train) + if fieldnames(typeof(opt)) !== () + train!(loss, opt.ps, data, opt.opt; cb = cb) + else + train!(loss, (), data, opt; cb = cb) + end +end diff --git a/src/optimise/optimisers.jl b/src/optimise/optimisers.jl index 02dbb547..f6590bdb 100644 --- a/src/optimise/optimisers.jl +++ b/src/optimise/optimisers.jl @@ -258,38 +258,52 @@ end mutable struct InvDecay gamma::Float64 - n::Int64 + state::IdDict end -InvDecay(γ = 0.001) = InvDecay(γ, 0) +InvDecay(γ = 0.001) = InvDecay(γ, IdDict()) function update!(o::InvDecay, x, Δ) - γ, n = o.gamma, o.n + γ = o.gamma + n = get!(o.state, x, 1) Δ .*= 1 / (1 + γ * n) - o.n += 1 + o.state[x] = n + 1 return Δ end mutable struct ExpDecay - gamma::Float64 + opt + decay::Float64 + step::Int64 + clip::Float64 + current::IdDict end -ExpDecay() = ExpDecay(0.001) +ExpDecay(opt, decay = 0.1, decay_step = 1000, clip = 1e-4) = ExpDecay(opt, decay, decay_step, clip, IdDict()) function update!(o::ExpDecay, x, Δ) - γ = o.gamma - @. Δ += γ * x + s, decay = o.step, o.decay + η = try o.opt.eta; catch e; o.opt.rho; end + n = o.current[x] = get(o.current, x, 0) + 1 + flag = false + count(x -> x%s == 0, values(o.current)) == 1 && (flag = true) + if o.current[x]%s == 0 && flag + η = max(η * decay^(s / n), o.clip) + o.opt isa ADADelta ? o.opt.rho = η : o.opt.eta = η + end + update!(o.opt, x, Δ) end mutable struct WeightDecay - eta::Real wd::Real end -WeightDecay(η = 1) = WeightDecay(η, 0) +WeightDecay() = WeightDecay(0) function update!(o::WeightDecay, x, Δ) - η, wd = o.eta, o.wd + wd = o.wd @. 
Δ += wd * x end -DescentWeightDecay(η = 1, wd = 0) = Optimiser(WeightDecay(1, wd), Descent(η)) \ No newline at end of file +DescentWeightDecay(η = 1, wd = 0) = Optimiser(WeightDecay(wd), Descent(η)) + +update!(opt::Function, ps) = opt() diff --git a/src/optimise/train.jl b/src/optimise/train.jl index 2fbe6b85..9fe459f6 100644 --- a/src/optimise/train.jl +++ b/src/optimise/train.jl @@ -4,9 +4,8 @@ import Base.depwarn function update!(opt, xs) for x in xs - x, Δ = data(x), grad(x) - Δ = update!(opt, x, Δ) - x .-= Δ + Δ = update!(opt, x.data, x.grad) + x.data .-= Δ Δ .= 0 end end @@ -62,14 +61,20 @@ The callback can return `:stop` to interrupt the training loop. Multiple optimisers and callbacks can be passed to `opt` and `cb` as arrays. """ -function train!(ps::Array, loss, data, opt; cb = () -> ()) +function train!(loss, ps, data, opt; cb = () -> ()) cb = runall(cb) opt = runall(opt) + opt = try + opt() + opt.opt + catch + opt + end @progress for d in data try l = loss(d...) @interrupts back!(l) - foreach(x -> x.data .-= update!(opt, x.data, x.grad), ps) + update!(opt, ps) if cb() == :stop depwarn("Use of `:stop` is deprecated; use `Flux.stop()` instead", :stop) break @@ -83,7 +88,6 @@ function train!(ps::Array, loss, data, opt; cb = () -> ()) end end end -train!(model, loss, data, opt; cb = () -> ()) = train!(params(model), loss, data, opt; cb = cb) """ @epochs N body diff --git a/test/optimise.jl b/test/optimise.jl index 0cbcf413..14d02224 100644 --- a/test/optimise.jl +++ b/test/optimise.jl @@ -16,7 +16,7 @@ using Test for t = 1: 10^5 l = loss(rand(10)) back!(l) - delta = Optimise.update!(opt, w′.data, w′.grad) + delta = Optimise.update!(opt, w′) w′.data .-= delta end @test Flux.mse(w, w′) < 0.01 @@ -25,14 +25,16 @@ end @testset "Optimiser" begin w = randn(10, 10) - @testset for Opt in [InvDecay, ExpDecay] + @testset for Opt in [InvDecay, WeightDecay, ExpDecay] w′ = param(randn(10, 10)) loss(x) = Flux.mse(w*x, w′*x) opt = Optimiser(Opt(), ADAM(0.001)) + if Opt isa ExpDecay + opt = ExpDecay(ADAM(), 0.9) for t = 1:10^5 l = loss(rand(10)) back!(l) - delta = Optimise.update!(opt, w′.data, w′.grad) + delta = Optimise.update!(opt, w′) w′.data .-= delta end @test Flux.mse(w, w′) < 0.01 @@ -45,7 +47,7 @@ end Flux.train!(() -> (sleep(0.1); i += 1; l), Iterators.repeated((), 100), - ADAM([l]), + () -> (), cb = Flux.throttle(() -> (i > 3 && Flux.stop()), 1)) @test 3 < i < 50 From ea508a79b007d094ac6b49212bd18a539cbac23d Mon Sep 17 00:00:00 2001 From: Dhairya Gandhi Date: Sat, 27 Oct 2018 19:39:56 +0530 Subject: [PATCH 16/23] use explicit update! 
rule --- test/optimise.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/optimise.jl b/test/optimise.jl index 14d02224..fa59cb2d 100644 --- a/test/optimise.jl +++ b/test/optimise.jl @@ -16,7 +16,7 @@ using Test for t = 1: 10^5 l = loss(rand(10)) back!(l) - delta = Optimise.update!(opt, w′) + delta = Optimise.update!(opt, w′.data, w′.grad) w′.data .-= delta end @test Flux.mse(w, w′) < 0.01 From 32ce2d78b8483e5553ec05107dc022d586ac5491 Mon Sep 17 00:00:00 2001 From: Dhairya Gandhi Date: Sat, 27 Oct 2018 19:53:06 +0530 Subject: [PATCH 17/23] fixed ExpDecay test --- src/optimise/optimisers.jl | 2 +- test/optimise.jl | 5 +++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/src/optimise/optimisers.jl b/src/optimise/optimisers.jl index f6590bdb..24f66267 100644 --- a/src/optimise/optimisers.jl +++ b/src/optimise/optimisers.jl @@ -279,7 +279,7 @@ mutable struct ExpDecay current::IdDict end -ExpDecay(opt, decay = 0.1, decay_step = 1000, clip = 1e-4) = ExpDecay(opt, decay, decay_step, clip, IdDict()) +ExpDecay(opt = Descent(), decay = 0.1, decay_step = 1000, clip = 1e-4) = ExpDecay(opt, decay, decay_step, clip, IdDict()) function update!(o::ExpDecay, x, Δ) s, decay = o.step, o.decay diff --git a/test/optimise.jl b/test/optimise.jl index fa59cb2d..f97d06f8 100644 --- a/test/optimise.jl +++ b/test/optimise.jl @@ -30,11 +30,12 @@ end loss(x) = Flux.mse(w*x, w′*x) opt = Optimiser(Opt(), ADAM(0.001)) if Opt isa ExpDecay - opt = ExpDecay(ADAM(), 0.9) + opt = ExpDecay(ADAM(), 0.9, 1000) + end for t = 1:10^5 l = loss(rand(10)) back!(l) - delta = Optimise.update!(opt, w′) + delta = Optimise.update!(opt, w′.data, w′.grad) w′.data .-= delta end @test Flux.mse(w, w′) < 0.01 From bebf4eb95f63fcd2946160b64b93fdba00c1f61f Mon Sep 17 00:00:00 2001 From: Dhairya Gandhi Date: Mon, 29 Oct 2018 23:12:24 +0530 Subject: [PATCH 18/23] fixed ExpDecay update! rule --- src/optimise/optimisers.jl | 11 +++++------ test/optimise.jl | 3 --- 2 files changed, 5 insertions(+), 9 deletions(-) diff --git a/src/optimise/optimisers.jl b/src/optimise/optimisers.jl index 24f66267..8881ffb0 100644 --- a/src/optimise/optimisers.jl +++ b/src/optimise/optimisers.jl @@ -272,26 +272,25 @@ function update!(o::InvDecay, x, Δ) end mutable struct ExpDecay - opt + eta::Float64 decay::Float64 step::Int64 clip::Float64 current::IdDict end -ExpDecay(opt = Descent(), decay = 0.1, decay_step = 1000, clip = 1e-4) = ExpDecay(opt, decay, decay_step, clip, IdDict()) +ExpDecay(opt = 0.001, decay = 0.1, decay_step = 1000, clip = 1e-4) = ExpDecay(opt, decay, decay_step, clip, IdDict()) function update!(o::ExpDecay, x, Δ) - s, decay = o.step, o.decay - η = try o.opt.eta; catch e; o.opt.rho; end + η, s, decay = o.eta, o.step, o.decay n = o.current[x] = get(o.current, x, 0) + 1 flag = false count(x -> x%s == 0, values(o.current)) == 1 && (flag = true) if o.current[x]%s == 0 && flag η = max(η * decay^(s / n), o.clip) - o.opt isa ADADelta ? o.opt.rho = η : o.opt.eta = η + o.eta = η end - update!(o.opt, x, Δ) + @. 
Δ *= decay end mutable struct WeightDecay diff --git a/test/optimise.jl b/test/optimise.jl index f97d06f8..78510a94 100644 --- a/test/optimise.jl +++ b/test/optimise.jl @@ -29,9 +29,6 @@ end w′ = param(randn(10, 10)) loss(x) = Flux.mse(w*x, w′*x) opt = Optimiser(Opt(), ADAM(0.001)) - if Opt isa ExpDecay - opt = ExpDecay(ADAM(), 0.9, 1000) - end for t = 1:10^5 l = loss(rand(10)) back!(l) From bffaceee029dc40ee934825d0bf30ca119190190 Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Wed, 31 Oct 2018 14:58:55 +0000 Subject: [PATCH 19/23] tweaks --- src/Flux.jl | 2 +- src/optimise/deprecations.jl | 56 ++++++++++++++++-------------------- src/optimise/optimisers.jl | 20 ++++++------- src/optimise/train.jl | 6 ---- test/optimise.jl | 3 +- 5 files changed, 36 insertions(+), 51 deletions(-) diff --git a/src/Flux.jl b/src/Flux.jl index b09cda17..7c72cbbc 100644 --- a/src/Flux.jl +++ b/src/Flux.jl @@ -21,7 +21,7 @@ using .Optimise using .Optimise: @epochs export Descent, ADAM, Momentum, Nesterov, RMSProp, ADAGrad, AdaMax, ADADelta, AMSGrad, NADAM, - ADAMW, InvDecay, ExpDecay, WeightDecay, DescentWeightDecay + ADAMW, InvDecay, ExpDecay, WeightDecay include("utils.jl") include("onehot.jl") diff --git a/src/optimise/deprecations.jl b/src/optimise/deprecations.jl index 228c3c29..8529799f 100644 --- a/src/optimise/deprecations.jl +++ b/src/optimise/deprecations.jl @@ -1,17 +1,6 @@ using Base: depwarn -function check_decay(opt, decay) - if decay == 0. - opt = opt - else - if opt isa ADAMW - opt = Optimiser(opt, WeightDecay(decay)) - else - opt = Optimiser(opt, InvDecay(decay)) - end - end - opt -end +check_decay(opt, decay) = decay == 0 ? opt : Optimiser(opt, InvDecay(decay)) # legacy update rule function updaterule(opt, ps) @@ -24,7 +13,7 @@ function updaterule(opt, ps) end function Descent(params::AbstractArray, η = 0.1; decay = 0.) - depwarn("Descent(ps::Param) is deprecated; use Descent(η::Float64) instead", :Descent) + depwarn("Descent(params) is deprecated; use Descent(η::Float64) instead", :Descent) ps = params opt = Descent(η) @@ -33,7 +22,7 @@ function Descent(params::AbstractArray, η = 0.1; decay = 0.) end function Momentum(params::AbstractArray, η = 0.01; ρ = 0.9, decay = 0.) - depwarn("Momentum(ps::Param) is deprecated; use Momentum(η::Float64) instead", :Momentum) + depwarn("Momentum(params) is deprecated; use Momentum(η::Float64) instead", :Momentum) ps = params opt = Momentum(η, ρ) @@ -42,7 +31,7 @@ function Momentum(params::AbstractArray, η = 0.01; ρ = 0.9, decay = 0.) end function Nesterov(params::AbstractArray, η = 0.001; ρ = 0.9, decay = 0.) - depwarn("Nesterov(ps::Param) is deprecated; use Nesterov(η::Float64) instead", :Nesterov) + depwarn("Nesterov(params) is deprecated; use Nesterov(η::Float64) instead", :Nesterov) ps = params opt = Nesterov(η, ρ) @@ -51,7 +40,7 @@ function Nesterov(params::AbstractArray, η = 0.001; ρ = 0.9, decay = 0.) end function RMSProp(params::AbstractArray, η = 0.001; ρ = 0.9, decay = 0.) - depwarn("RMSProp(ps::Param) is deprecated; use RMSProp(η::Float64) instead", :RMSProp) + depwarn("RMSProp(params) is deprecated; use RMSProp(η::Float64) instead", :RMSProp) ps = params opt = RMSProp(η, ρ) @@ -60,7 +49,7 @@ function RMSProp(params::AbstractArray, η = 0.001; ρ = 0.9, decay = 0.) end function ADAM(params::AbstractArray, η = 0.001; β1 = 0.9, β2 = 0.999, decay = 0.) 
- depwarn("ADAM(ps::Param) is deprecated; use ADAM(η::Float64) instead", :ADAM) + depwarn("ADAM(params) is deprecated; use ADAM(η::Float64) instead", :ADAM) ps = params β = (β1, β2) @@ -70,7 +59,7 @@ function ADAM(params::AbstractArray, η = 0.001; β1 = 0.9, β2 = 0.999, decay = end function ADAGrad(params::AbstractArray, η::Float64 = 0.1; decay = 0.) - depwarn("ADAGrad(ps::Param) is deprecated; use ADAGrad(η::Float64) instead", :ADAGrad) + depwarn("ADAGrad(params) is deprecated; use ADAGrad(η::Float64) instead", :ADAGrad) ps = params opt = ADAGrad(η) @@ -79,7 +68,7 @@ function ADAGrad(params::AbstractArray, η::Float64 = 0.1; decay = 0.) end function ADADelta(params::AbstractArray, ρ::Float64 = 0.9; decay = 0.) - depwarn("ADADelta(ps::Param) is deprecated; use ADADelta(η::Float64) instead", :ADADelta) + depwarn("ADADelta(params) is deprecated; use ADADelta(η::Float64) instead", :ADADelta) ps = params opt = ADADelta(ρ) @@ -88,7 +77,7 @@ function ADADelta(params::AbstractArray, ρ::Float64 = 0.9; decay = 0.) end function AdaMax(params::AbstractArray, η = 0.001; β1 = 0.9, β2 = 0.999, decay = 0.) - depwarn("AdaMax(ps::Param) is deprecated; use AdaMax(η::Float64) instead", :AdaMax) + depwarn("AdaMax(params) is deprecated; use AdaMax(η::Float64) instead", :AdaMax) ps = params β = (β1, β2) @@ -98,7 +87,7 @@ function AdaMax(params::AbstractArray, η = 0.001; β1 = 0.9, β2 = 0.999, decay end function AMSGrad(params::AbstractArray, η = 0.001; β1 = 0.9, β2 = 0.999, decay = 0.) - depwarn("AMSGrad(ps::Param) is deprecated; use AMSGrad(η::Float64) instead", :AMSGrad) + depwarn("AMSGrad(params) is deprecated; use AMSGrad(η::Float64) instead", :AMSGrad) ps = params β = (β1, β2) @@ -108,7 +97,7 @@ function AMSGrad(params::AbstractArray, η = 0.001; β1 = 0.9, β2 = 0.999, deca end function NADAM(params::AbstractArray, η = 0.001; β1 = 0.9, β2 = 0.999, decay = 0.) - depwarn("NADAM(ps::Param) is deprecated; use NADAM(η::Float64) instead", :NADAM) + depwarn("NADAM(params) is deprecated; use NADAM(η::Float64) instead", :NADAM) ps = params β = (β1, β2) @@ -118,21 +107,26 @@ function NADAM(params::AbstractArray, η = 0.001; β1 = 0.9, β2 = 0.999, decay end function ADAMW(params::AbstractArray, η = 0.001; β1 = 0.9, β2 = 0.999, decay = 0.) - depwarn("ADAMW(ps::Param) is deprecated; use ADAMW(η::Float64) instead", :ADAMW) + depwarn("ADAMW(params) is deprecated; use ADAMW(η::Float64) instead", :ADAMW) ps = params β = (β1, β2) opt = ADAMW(η, β) opt = check_decay(opt, decay) + decay != 0 && (opt = Optimiser(opt, WeightDecay(decay))) updaterule(opt, ps) end -# Train function -function train!(loss::Function, data, opt; cb = () -> ()) - depwarn("train!(loss, data, opt; cb) is deprecated; use train!(loss, params, data, opt; cb) instead", :train) - if fieldnames(typeof(opt)) !== () - train!(loss, opt.ps, data, opt.opt; cb = cb) - else - train!(loss, (), data, opt; cb = cb) - end +# Old training loop + +struct OldOptimiser + func +end + +update!(opt::OldOptimiser, ps) = opt.func() + +# Train function +function train!(loss, data, opt; cb = () -> ()) + depwarn("train!(loss, data, opt) is deprecated; use train!(loss, params, data, opt) instead", :train!) 
+ train!(loss, (), data, OldOptimiser(opt); cb = cb) end diff --git a/src/optimise/optimisers.jl b/src/optimise/optimisers.jl index 8881ffb0..2accc4bc 100644 --- a/src/optimise/optimisers.jl +++ b/src/optimise/optimisers.jl @@ -17,6 +17,7 @@ mutable struct Descent end Descent() = Descent(0.1) + function update!(o::Descent, x, Δ) Δ .*= o.eta end @@ -152,7 +153,7 @@ function update!(o::ADAGrad, x, Δ) end """ - ADADelta(params; ρ = 0.9, ϵ = 1e-8) + ADADelta(ρ = 0.9, ϵ = 1e-8) [ADADelta](http://arxiv.org/abs/1212.5701) optimiser. Parameters don't need tuning. @@ -222,16 +223,18 @@ function update!(o::NADAM, x, Δ) end """ - ADAMW((η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0) + ADAMW((η = 0.001, β = (0.9, 0.999), decay = 0) [ADAMW](https://arxiv.org/abs/1711.05101) fixing weight decay regularization in Adam. """ -ADAMW(η = 0.001, β = (0.9, 0.999), η_decay = 1, wd = 0) = Optimiser(ADAM(η, β, IdDict()), DescentWeightDecay(η_decay, wd)) +ADAMW(η = 0.001, β = (0.9, 0.999), decay = 0) = + Optimiser(ADAM(η, β), WeightDecay(wd)) # Compose optimizers """ Optimiser(a, b, c...) + Combine several optimisers into one; each optimiser produces a modified gradient that will be fed into the next, and this is finally applied to the parameter as usual. @@ -254,8 +257,6 @@ function update!(o::Optimiser, x, Δ) return Δ end -# TODO: decay - mutable struct InvDecay gamma::Float64 state::IdDict @@ -284,9 +285,7 @@ ExpDecay(opt = 0.001, decay = 0.1, decay_step = 1000, clip = 1e-4) = ExpDecay(op function update!(o::ExpDecay, x, Δ) η, s, decay = o.eta, o.step, o.decay n = o.current[x] = get(o.current, x, 0) + 1 - flag = false - count(x -> x%s == 0, values(o.current)) == 1 && (flag = true) - if o.current[x]%s == 0 && flag + if o.current[x]%s == 0 && count(x -> x%s == 0, values(o.current)) == 1 η = max(η * decay^(s / n), o.clip) o.eta = η end @@ -298,11 +297,8 @@ mutable struct WeightDecay end WeightDecay() = WeightDecay(0) + function update!(o::WeightDecay, x, Δ) wd = o.wd @. Δ += wd * x end - -DescentWeightDecay(η = 1, wd = 0) = Optimiser(WeightDecay(wd), Descent(η)) - -update!(opt::Function, ps) = opt() diff --git a/src/optimise/train.jl b/src/optimise/train.jl index 9fe459f6..28bdf27b 100644 --- a/src/optimise/train.jl +++ b/src/optimise/train.jl @@ -64,12 +64,6 @@ Multiple optimisers and callbacks can be passed to `opt` and `cb` as arrays. function train!(loss, ps, data, opt; cb = () -> ()) cb = runall(cb) opt = runall(opt) - opt = try - opt() - opt.opt - catch - opt - end @progress for d in data try l = loss(d...) 
diff --git a/test/optimise.jl b/test/optimise.jl index 78510a94..98d06edb 100644 --- a/test/optimise.jl +++ b/test/optimise.jl @@ -44,8 +44,9 @@ end l = param(1) Flux.train!(() -> (sleep(0.1); i += 1; l), + (), Iterators.repeated((), 100), - () -> (), + Descent(), cb = Flux.throttle(() -> (i > 3 && Flux.stop()), 1)) @test 3 < i < 50 From 4a54d30cbf364988cb31a9a8fc06ba74fda93e05 Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Wed, 31 Oct 2018 15:30:30 +0000 Subject: [PATCH 20/23] correct SGD deprecation --- src/Flux.jl | 6 +++--- src/optimise/Optimise.jl | 2 +- src/optimise/deprecations.jl | 4 ++-- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/Flux.jl b/src/Flux.jl index 7c72cbbc..d285b5a9 100644 --- a/src/Flux.jl +++ b/src/Flux.jl @@ -19,9 +19,9 @@ export Tracker, TrackedArray, TrackedVector, TrackedMatrix, param include("optimise/Optimise.jl") using .Optimise using .Optimise: @epochs -export Descent, ADAM, Momentum, Nesterov, RMSProp, - ADAGrad, AdaMax, ADADelta, AMSGrad, NADAM, - ADAMW, InvDecay, ExpDecay, WeightDecay +export SGD, Descent, ADAM, Momentum, Nesterov, RMSProp, + ADAGrad, AdaMax, ADADelta, AMSGrad, NADAM, + ADAMW, InvDecay, ExpDecay, WeightDecay include("utils.jl") include("onehot.jl") diff --git a/src/optimise/Optimise.jl b/src/optimise/Optimise.jl index cf12a3c3..5bb38d1e 100644 --- a/src/optimise/Optimise.jl +++ b/src/optimise/Optimise.jl @@ -1,7 +1,7 @@ module Optimise export train!, - Descent, ADAM, Momentum, Nesterov, RMSProp, + SGD, Descent, ADAM, Momentum, Nesterov, RMSProp, ADAGrad, AdaMax, ADADelta, AMSGrad, NADAM, ADAMW, InvDecay, ExpDecay, WeightDecay, stop, Optimiser diff --git a/src/optimise/deprecations.jl b/src/optimise/deprecations.jl index 8529799f..d04a5447 100644 --- a/src/optimise/deprecations.jl +++ b/src/optimise/deprecations.jl @@ -12,8 +12,8 @@ function updaterule(opt, ps) end end -function Descent(params::AbstractArray, η = 0.1; decay = 0.) - depwarn("Descent(params) is deprecated; use Descent(η::Float64) instead", :Descent) +function SGD(params::AbstractArray, η = 0.1; decay = 0.) + depwarn("SGD(params) is deprecated; use Descent(η::Float64) instead", :SGD) ps = params opt = Descent(η) From 554c4c7c7ac3be1c7e77b1a7693bf905122e13de Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Wed, 31 Oct 2018 15:50:08 +0000 Subject: [PATCH 21/23] return Params from params --- src/tracker/back.jl | 21 +++++++++++++++++---- src/tracker/idset.jl | 1 + src/treelike.jl | 2 +- 3 files changed, 19 insertions(+), 5 deletions(-) diff --git a/src/tracker/back.jl b/src/tracker/back.jl index e5a84a71..2be772b0 100644 --- a/src/tracker/back.jl +++ b/src/tracker/back.jl @@ -66,15 +66,28 @@ end # Out-of-place gradients struct Params - params::IdSet - Params(xs) = new(IdSet(xs)) + order::Vector{Any} + params::IdSet{Any} + Params() = new([], IdSet()) end -@forward Params.params Base.iterate, Base.length +@forward Params.order Base.iterate, Base.length + +function Base.push!(ps::Params, x) + if !(x in ps.params) + push!(ps.order, x) + push!(ps.params, x) + end + return ps +end + +Base.push!(ps::Params, x...) = (foreach(x -> push!(ps, x), x); ps) + +Params(xs) = push!(Params(), xs...) 
function Base.show(io::IO, ps::Params) print(io, "Params([") - join(io, ps.params, ", ") + join(io, ps.order, ", ") print(io, "])") end diff --git a/src/tracker/idset.jl b/src/tracker/idset.jl index 62570c99..372e262a 100644 --- a/src/tracker/idset.jl +++ b/src/tracker/idset.jl @@ -7,6 +7,7 @@ Base.eltype(::IdSet{T}) where T = T IdSet() = IdSet{Any}() +Base.push!(s::IdSet) = s Base.push!(s::IdSet{T}, x::T) where T = (s.dict[x] = nothing; s) Base.delete!(s::IdSet{T}, x::T) where T = (delete!(s.dict, x); s) Base.in(x, s::IdSet) = haskey(s.dict, x) diff --git a/src/treelike.jl b/src/treelike.jl index 3d83d448..ae94590b 100644 --- a/src/treelike.jl +++ b/src/treelike.jl @@ -40,7 +40,7 @@ function prefor(f, x; seen = IdSet()) end function params(m) - ps = [] + ps = Params() prefor(p -> Tracker.istracked(p) && Tracker.isleaf(p) && !any(p′ -> p′ === p, ps) && push!(ps, p), From 46049b9f4498544d8b62c69713357e1baae0a5d0 Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Wed, 31 Oct 2018 16:08:18 +0000 Subject: [PATCH 22/23] tweak update rule --- src/optimise/deprecations.jl | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/src/optimise/deprecations.jl b/src/optimise/deprecations.jl index d04a5447..40c695b6 100644 --- a/src/optimise/deprecations.jl +++ b/src/optimise/deprecations.jl @@ -3,14 +3,7 @@ using Base: depwarn check_decay(opt, decay) = decay == 0 ? opt : Optimiser(opt, InvDecay(decay)) # legacy update rule -function updaterule(opt, ps) - () -> begin - for p in ps - delta = update!(opt, p.data, p.grad) - p.data .-= delta - end - end -end +updaterule(opt, ps) = () -> update!(p, ps) function SGD(params::AbstractArray, η = 0.1; decay = 0.) depwarn("SGD(params) is deprecated; use Descent(η::Float64) instead", :SGD) From b05cd41c99bcee7aae18efa655a1d6d413deb07d Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Wed, 31 Oct 2018 16:26:14 +0000 Subject: [PATCH 23/23] require 1.0 --- .travis.yml | 1 - REQUIRE | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/.travis.yml b/.travis.yml index b26597e9..d1fd28ad 100644 --- a/.travis.yml +++ b/.travis.yml @@ -4,7 +4,6 @@ os: - linux # - osx julia: - - 0.7 - 1.0 - nightly # uncomment the following lines to override the default test script diff --git a/REQUIRE b/REQUIRE index ad3306d6..feec31c3 100644 --- a/REQUIRE +++ b/REQUIRE @@ -1,4 +1,4 @@ -julia 0.7 +julia 1.0 Juno MacroTools 0.3.3 NNlib
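
A minimal usage sketch of the optimiser interface this series converges on (assuming a tree at the tip of the final patch; the parameter `W`, the `loss`, and the random data below are illustrative stand-ins, not part of the patches — the calls mirror `test/optimise.jl` and the `train!` docstring above):

```julia
using Flux, Flux.Tracker, Flux.Optimise

W = param(randn(2, 5))              # a single tracked parameter standing in for a model
loss(x, y) = Flux.mse(W * x, y)

# Optimisers are now plain structs holding hyperparameters plus per-parameter
# state; decay is composed in via Optimiser rather than passed as a keyword.
opt = Optimiser(WeightDecay(1e-4), ADAM(0.001))

# Manual step: each optimiser in the chain rewrites the gradient, and the
# caller applies (and then clears) whatever comes out the other end.
x, y = rand(5), rand(2)
back!(loss(x, y))
Δ = Flux.Optimise.update!(opt, W.data, W.grad)
W.data .-= Δ
W.grad .= 0

# Or let train! drive the loop; the parameters are now an explicit argument.
data = Iterators.repeated((x, y), 100)
Flux.train!(loss, [W], data, opt,
            cb = Flux.throttle(() -> println("training"), 10))
```

The key design point of the refactor is that an optimiser no longer closes over its parameters: it is a small mutable struct of hyperparameters with an `IdDict` of per-parameter state, so a single instance can be reused across parameters, composed with decay terms through `Optimiser`, and swapped out without rebuilding the training loop.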