From a2d2d068aa0b60c228b0552de29981c273818ce1 Mon Sep 17 00:00:00 2001
From: Mike J Innes
Date: Thu, 31 May 2018 20:29:59 +0100
Subject: [PATCH] initial sketch

---
 src/Flux.jl                |   4 +-
 src/optimise/Optimise.jl   |  18 +--
 src/optimise/interface.jl  | 110 ------------------
 src/optimise/optimisers.jl | 223 +++++++++++++++++++------------------
 src/optimise/train.jl      |  13 ++-
 src/tracker/Tracker.jl     |  87 +++------------
 6 files changed, 148 insertions(+), 307 deletions(-)
 delete mode 100644 src/optimise/interface.jl

diff --git a/src/Flux.jl b/src/Flux.jl
index 614eeaf7..6ec849b0 100644
--- a/src/Flux.jl
+++ b/src/Flux.jl
@@ -19,8 +19,8 @@ export Tracker, TrackedArray, TrackedVector, TrackedMatrix, param
 include("optimise/Optimise.jl")
 using .Optimise
 using .Optimise: @epochs
-export SGD, ADAM, ADAMW, AdaMax, Momentum, Nesterov,
-  RMSProp, ADAGrad, ADADelta, AMSGrad, NADAM
+export SGD, Descent, ADAM, AdaMax, Momentum, Nesterov,
+  RMSProp, ADAGrad, ADADelta, AMSGrad
 
 include("utils.jl")
 include("onehot.jl")
diff --git a/src/optimise/Optimise.jl b/src/optimise/Optimise.jl
index c4828c9e..d54e4453 100644
--- a/src/optimise/Optimise.jl
+++ b/src/optimise/Optimise.jl
@@ -1,23 +1,9 @@
 module Optimise
 
 export train!,
-  SGD, ADAM, ADAMW, AdaMax, Momentum, Nesterov,
-  RMSProp, ADAGrad, ADADelta, AMSGrad, NADAM, stop, StopException
-
-struct Param{T}
-  x::T
-  Δ::T
-end
-
-Param(x::AbstractArray) = Param(x, zero(x))
+  SGD, Descent, ADAM, AdaMax, Momentum, Nesterov, RMSProp, ADAGrad, ADADelta, AMSGrad
 
 include("optimisers.jl")
-include("interface.jl")
 include("train.jl")
 
-using Flux.Tracker: TrackedArray
-
-Param(x::TrackedArray) = Param(x.data, x.grad)
-# Base.convert(::Type{Param}, x::TrackedArray) = Param(x.data, x.grad)
-
-end
+end
\ No newline at end of file
diff --git a/src/optimise/interface.jl b/src/optimise/interface.jl
deleted file mode 100644
index 096e2d87..00000000
--- a/src/optimise/interface.jl
+++ /dev/null
@@ -1,110 +0,0 @@
-call(f, xs...) = f(xs...)
-
-# note for optimisers: set to zero
-# p.Δ at the end of the weights update
-function optimiser(ps, fs...)
-  ps = [Param(p) for p in ps]
-  fs = map(ps) do p
-    os = map(f -> f(p), fs)
-    () -> foreach(call, os)
-  end
-  () -> foreach(call, fs)
-end
-
-"""
-    SGD(params, η = 0.1; decay = 0)
-
-Classic gradient descent optimiser with learning rate `η`.
-For each parameter `p` and its gradient `δp`, this runs `p -= η*δp`.
-
-Supports inverse decaying learning rate if the `decay` argument is provided.
-"""
-SGD(ps, η = 0.1; decay = 0) =
-  optimiser(ps, p -> invdecay(p, decay), p -> descent(p,η))
-
-"""
-    Momentum(params, η = 0.01; ρ = 0.9, decay = 0)
-
-SGD with learning rate `η`, momentum `ρ` and optional learning rate inverse decay.
-"""
-Momentum(ps, η = 0.01; ρ = 0.9, decay = 0) =
-  optimiser(ps, p->invdecay(p,decay), p->momentum(p, ρ, η), p->descent(p,1))
-
-"""
-    Nesterov(params, η = 0.01; ρ = 0.9, decay = 0)
-
-SGD with learning rate `η`, Nesterov momentum `ρ` and optional learning rate inverse decay.
-"""
-Nesterov(ps, η = 0.01; ρ = 0.9, decay = 0) =
-  optimiser(ps, p->invdecay(p,decay), p->nesterov(p, ρ, η), p->descent(p,1))
-
-"""
-    RMSProp(params, η = 0.001; ρ = 0.9, ϵ = 1e-8, decay = 0)
-
-[RMSProp](http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf)
-optimiser. Parameters other than learning rate don't need tuning. Often a good
-choice for recurrent networks.
-"""
-RMSProp(ps, η = 0.001; ρ = 0.9, ϵ = 1e-8, decay = 0) =
-  optimiser(ps, p->rmsprop(p; η=η, ρ=ρ, ϵ=ϵ), p->invdecay(p,decay), p->descent(p,1))
-
-"""
-    ADAM(params, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0)
-
-[ADAM](https://arxiv.org/abs/1412.6980v8) optimiser.
-"""
-ADAM(ps, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0) =
-  optimiser(ps, p->adam(p; η=η, β1=β1, β2=β2, ϵ=ϵ), p->invdecay(p,decay), p->descent(p,1))
-
-"""
-    ADAMW((params, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0)
-
-[ADAMW](https://arxiv.org/abs/1711.05101) fixing weight decay regularization in Adam.
-"""
-ADAMW(ps, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0) =
-  optimiser(ps, p->adam(p; η=η, β1=β1, β2=β2, ϵ=ϵ), p->descentweightdecay(p,1,decay))
-
-"""
-    AdaMax(params, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0)
-
-[AdaMax](https://arxiv.org/abs/1412.6980v9) optimiser. Variant of ADAM based on
-the ∞-norm.
-"""
-AdaMax(ps, η = 0.002; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0) =
-  optimiser(ps, p->adamax(p; η=η, β1=β1, β2=β2, ϵ=ϵ), p->invdecay(p,decay), p->descent(p,1))
-
-"""
-    ADAGrad(params, η = 0.01; ϵ = 1e-8, decay = 0)
-
-[ADAGrad](http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf) optimiser.
-Parameters don't need tuning.
-"""
-ADAGrad(ps, η = 0.01; ϵ = 1e-8, decay = 0) =
-  optimiser(ps, p->adagrad(p; η=η, ϵ=ϵ), p->invdecay(p,decay), p->descent(p,1))
-
-"""
-    ADADelta(params; ρ = 0.9, ϵ = 1e-8, decay = 0)
-
-[ADADelta](http://arxiv.org/abs/1212.5701) optimiser. Parameters don't need
-tuning.
-"""
-ADADelta(ps; ρ = 0.9, ϵ = 1e-8, decay = 0) =
-  optimiser(ps, p->adadelta(p; ρ=ρ, ϵ=ϵ), p->descent(p,1))
-
-"""
-    AMSGrad(params; η = 0.001, β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0)
-
-[AMSGrad](https://openreview.net/forum?id=ryQu7f-RZ) optimiser. Parameters don't need
-tuning.
-"""
-AMSGrad(ps, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0) =
-  optimiser(ps, p -> amsgrad(p; η = η, β1 = β1, β2 = β2, ϵ = ϵ), p -> invdecay(p, decay), p -> descent(p, 1))
-
-"""
-    NADAM(params, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0)
-
-[NADAM](https://openreview.net/pdf?id=OM0jvwB8jIp57ZJjtNEZ) optimiser. Parameters other
-than learning rate don't need tuning.
-"""
-NADAM(ps, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0) =
-  optimiser(ps, p->nadam(p; η=η, β1=β1, β2=β2, ϵ=ϵ), p->invdecay(p,decay), p->descent(p,1))
diff --git a/src/optimise/optimisers.jl b/src/optimise/optimisers.jl
index 1f7a7c9c..cfbbcfe9 100644
--- a/src/optimise/optimisers.jl
+++ b/src/optimise/optimisers.jl
@@ -1,130 +1,139 @@
-function descent(p::Param, η::Real)
-  function ()
-    @. p.x -= η * p.Δ
-    @. p.Δ = 0
-  end
+using Flux
+using Base: @get!
+
+const ϵ = 1e-8
+
+# TODO: should use weak refs
+
+"""
+    Descent(η)
+
+Classic gradient descent optimiser with learning rate `η`.
+For each parameter `p` and its gradient `δp`, this runs `p -= η*δp`.
+"""
+mutable struct Descent
+  eta::Float64
 end
 
-# Ref: https://arxiv.org/abs/1711.05101.pdf
-function descentweightdecay(p::Param, η::Real, γ::Real)
-  function ()
-    @. p.x = p.x - η * (p.Δ + γ * p.x)
-    @. p.Δ = 0
-  end
+function update!(o::Descent, x, Δ)
+  Δ .*= o.eta
 end
 
-function momentum(p::Param, ρ, η)
-  v = zero(p.x)
-  function ()
-    @. v = ρ * v - η * p.Δ
-    @. p.Δ = -v
-  end
+"""
+    Momentum(params, η = 0.01; ρ = 0.9, decay = 0)
+
+Gradient descent with learning rate `η` and momentum `ρ`.
+"""
+mutable struct Momentum
+  eta::Float64
+  rho::Float64
+  velocity::ObjectIdDict
 end
 
-# Ref. https://arxiv.org/pdf/1212.0901.pdf
-function nesterov(p::Param, ρ, η)
-  v = zero(p.x)
-  function ()
-    d = @. ρ^2 * v - (1+ρ) * η * p.Δ
-    @. v = ρ*v - η*p.Δ
-    @. p.Δ = -d
-  end
+Momentum(η, ρ = 0.9) = Momentum(η, ρ, ObjectIdDict())
+
+function update!(o::Momentum, x, Δ)
+  η, ρ = o.eta, o.rho
+  v = @get!(o.velocity, x, zero(x))::typeof(x)
+  @. v = ρ * v - η * Δ
+  @. Δ = -v
 end
 
-function rmsprop(p::Param; η::Real = 0.001, ρ::Real = 0.9, ϵ::Real = 1e-8)
-  acc = zero(p.x)
-  function ()
-    @. acc = ρ * acc + (1 - ρ) * p.Δ^2
-    @. p.Δ *= η / √(acc + ϵ)
-  end
+"""
+    Nesterov(eta, ρ = 0.9)
+
+Gradient descent with learning rate `η` and Nesterov momentum `ρ`.
+"""
+mutable struct Nesterov
+  eta::Float64
+  rho::Float64
+  velocity::ObjectIdDict
 end
 
-function adagrad(p::Param; η::Real = 0.01, ϵ::Real = 1e-8)
-  acc = zero(p.x) .+ ϵ
-  function ()
-    @. acc += p.Δ^2
-    @. p.Δ *= η / √(acc + ϵ)
-  end
+Nesterov(η, ρ = 0.9) = Nesterov(η, ρ, ObjectIdDict())
+
+function update!(o::Nesterov, x, Δ)
+  η, ρ = o.eta, o.rho
+  v = @get!(o.velocity, x, zero(x))::typeof(x)
+  d = @. ρ^2 * v - (1+ρ) * η * Δ
+  @. v = ρ*v - η*Δ
+  @. Δ = -d
 end
 
-function adadelta(p::Param; ρ::Real = 0.9, ϵ::Real = 1e-8)
-  acc = zero(p.x)
-  Δacc = zero(p.x)
-  function ()
-    @. acc = ρ * acc + (1 - ρ) * p.Δ^2
-    @. p.Δ *= √(Δacc + ϵ) / √(acc + ϵ)
-    @. Δacc = ρ * Δacc + (1 - ρ) * p.Δ^2
-  end
+"""
+    RMSProp(η = 0.001, ρ = 0.9)
+
+[RMSProp](http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf)
+optimiser. Parameters other than learning rate don't need tuning. Often a good
+choice for recurrent networks.
+"""
+mutable struct RMSProp
+  eta::Float64
+  rho::Float64
+  acc::ObjectIdDict
 end
 
-function adam(p::Param; η::Real = 0.001, β1::Real = 0.9, β2::Real = 0.999, ϵ::Real = 1e-8)
-  mt = zero(p.x)
-  vt = zero(p.x)
-  β1p, β2p = β1, β2
-  function ()
-    @. mt = β1 * mt + (1 - β1) * p.Δ
-    @. vt = β2 * vt + (1 - β2) * p.Δ^2
-    @. p.Δ = mt / (1 - β1p) / √(vt / (1 - β2p) + ϵ) * η
-    β1p *= β1
-    β2p *= β2
-  end
+RMSProp(η = 0.001, ρ = 0.9) = RMSProp(η, ρ, ObjectIdDict())
+
+function update!(o::RMSProp, x, Δ)
+  η, ρ = o.eta, o.rho
+  acc = @get!(o.acc, x, zero(x))::typeof(x)
+  @. acc = ρ * acc + (1 - ρ) * Δ^2
+  @. Δ *= η / (√acc + ϵ)
 end
 
-function adamax(p::Param; η::Real = 0.002, β1::Real = 0.9, β2::Real = 0.999, ϵ::Real = 1e-8)
-  mt = zero(p.x)
-  ut = zero(p.x)
-  β1p = β1
-  function ()
-    @. mt = β1 * mt + (1 - β1) * p.Δ
-    @. ut = max(β2 * ut, abs(p.Δ))
-    @. p.Δ = (η/(1 - β1p)) * mt/(ut + ϵ)
-    β1p *= β1
-  end
+"""
+    ADAM(params, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0)
+
+[ADAM](https://arxiv.org/abs/1412.6980v8) optimiser.
+"""
+mutable struct ADAM
+  eta::Float64
+  beta::Tuple{Float64,Float64}
+  state::ObjectIdDict
 end
 
-function amsgrad(p::Param; η::Real = 0.001, β1::Real = 0.9, β2::Real = 0.999, ϵ::Real = 1e-8)
-  mt = zero(p.x)
-  vt = zero(p.x) .+ ϵ
-  v̂t = zero(p.x) .+ ϵ
-  function ()
-    @. mt = β1 * mt + (1 - β1) * p.Δ
-    @. vt = β2 * vt + (1 - β2) * p.Δ ^ 2
-    @. v̂t = max.(v̂t, vt)
-    @. p.Δ = η * mt / √v̂t
-  end
+ADAM(η = 0.001, β = (0.9, 0.999)) = ADAM(η, β, ObjectIdDict())
+
+function update!(o::ADAM, x, Δ)
+  η, β = o.eta, o.beta
+  mt, vt, βp = @get!(o.state, x, (zero(x), zero(x), β))
+  @. mt = β[1] * mt + (1 - β[1]) * Δ
+  @. vt = β[2] * vt + (1 - β[2]) * Δ^2
+  @. Δ = mt / (1 - βp[1]) / (√(vt / (1 - βp[2])) + ϵ) * η
+  o.state[x] = (mt, vt, βp .* β)
 end
 
-function nadam(p::Param; η::Real = 0.001, β1::Real = 0.9, β2::Real = 0.999, ϵ::Real = 1e-8)
-  mt = zero(p.x)
-  vt = zero(p.x)
-  β1p, β2p = β1, β2
-  function ()
-    @. mt = β1 * mt + (1 - β1) * p.Δ
-    @. vt = β2 * vt + (1 - β2) * p.Δ^2
-    @. p.Δ = (β1 * mt / (1 - β1 * β1p) + (1 - β1) * p.Δ / (1 - β1p)) / √(vt * β2 / (1 - β2p) + ϵ) * η
-    β1p *= β1
-    β2p *= β2
-  end
-end
+# """
+#     AdaMax(params, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0)
+#
+# [AdaMax](https://arxiv.org/abs/1412.6980v9) optimiser. Variant of ADAM based on
+# the ∞-norm.
+# """
 
-clip(p::Param, thresh::Real) = () -> clamp!(p.Δ, -thresh, thresh)
+# """
+#     ADAGrad(params, η = 0.01; ϵ = 1e-8, decay = 0)
+#
+# [ADAGrad](http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf) optimiser.
+# Parameters don't need tuning.
+# """
 
-function expdecay(p::Param, γ::Real)
-  if γ != 0
-    return () -> p.Δ .+= γ .* p.x
-  else
-    return () -> nothing
-  end
-end
+# """
+#     ADADelta(params; ρ = 0.9, ϵ = 1e-8, decay = 0)
+#
+# [ADADelta](http://arxiv.org/abs/1212.5701) optimiser. Parameters don't need
+# tuning.
+# """
 
-function invdecay(p::Param, γ::Real)
-  if γ != 0
-    n = 0
-    return () -> begin
-      p.Δ .*= 1 / (1 + γ * n)
-      n += 1
-    end
-  else
-    return () -> nothing
-  end
-end
+# """
+#     AMSGrad(params; η = 0.001, β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0)
+#
+# [AMSGrad](https://openreview.net/forum?id=ryQu7f-RZ) optimiser. Parameters don't need
+# tuning.
+# """
+
+# struct Optimiser
+#   os::Vector{Any}
+# end
+
+# TODO: decay
diff --git a/src/optimise/train.jl b/src/optimise/train.jl
index 09893873..85c402e6 100644
--- a/src/optimise/train.jl
+++ b/src/optimise/train.jl
@@ -1,7 +1,16 @@
 using Juno
-using Flux.Tracker: back!
-import Base.depwarn
+using Flux.Tracker: data, grad, back!
 
+function update!(opt, xs)
+  for x in xs
+    x, Δ = data(x), grad(x)
+    update!(opt, x, Δ)
+    x .-= Δ
+    Δ .= 0
+  end
+end
+
+# Callback niceties
 runall(f) = f
 runall(fs::AbstractVector) = () -> foreach(call, fs)
 
diff --git a/src/tracker/Tracker.jl b/src/tracker/Tracker.jl
index 190837ab..036c0904 100644
--- a/src/tracker/Tracker.jl
+++ b/src/tracker/Tracker.jl
@@ -1,27 +1,23 @@
 module Tracker
 
-using MacroTools
-using MacroTools: @q, @forward
-
 import Base: ==
 
-export TrackedArray, TrackedVector, TrackedMatrix, Params, param, back!
+export TrackedArray, TrackedVector, TrackedMatrix, param, back!
 
 tracker(x) = nothing
 
 istracked(x) = tracker(x) ≠ nothing
 isleaf(x) = !istracked(x) || isleaf(tracker(x))
+data(x) = istracked(x) ? data(tracker(x)) : x
 grad(x) = grad(tracker(x))
 grad(::Nothing) = nothing
-data(x) = x
 
 struct Call{F,As<:Tuple}
   func::F
   args::As
 end
 
-Call(f::F, args::T) where {F,T} = Call{F,T}(f, args)
-Call() = Call(nothing, ())
+Call(f, args...) = Call{typeof(f),typeof(args)}(f, args)
 
 # When deserialising, the object_id changes
 a::Call == b::Call = a.func == b.func && a.args == b.args
@@ -32,86 +28,37 @@ mutable struct Tracked{T}
   ref::UInt32
   f::Call
   isleaf::Bool
+  data::T
   grad::T
-  Tracked{T}(f::Call) where T = new(0, f, false)
-  Tracked{T}(f::Call, grad::T) where T = new(0, f, false, grad)
-  Tracked{T}(f::Call{Nothing}, grad::T) where T = new(0, f, true, grad)
+  Tracked{T}(f::Call, data::T) where T = new(0, f, false, data)
+  Tracked{T}(f::Call, data::T, grad::T) where T = new(0, f, false, data, grad)
+  Tracked{T}(f::Call{Void}, data::T, grad::T) where T = new(0, f, true, data, grad)
 end
 
+Tracked(f::Call, x) = Tracked{typeof(x)}(f, x)
+Tracked(f::Call, x, Δ) = Tracked{typeof(x)}(f, x, Δ)
+
+track(f::Call, x) = Tracked(f, x)
+track(f::Call) = track(f, f())
+track(f, xs...) = track(Call(f, xs...))
+
 istracked(x::Tracked) = true
-isleaf(x::Tracked) = x.f == Call()
+isleaf(x::Tracked) = x.f == Call(nothing)
+data(x::Tracked) = x.data
 grad(x::Tracked) = x.grad
 
-track(f::Call, x) = Tracked{typeof(x)}(f)
-
-function _forward end
-
-function track(f::F, xs...; kw...) where F
-  y, back = _forward(f, xs...; kw...)
-  track(Call(back, tracker.(xs)), y)
-end
-
-macro grad(ex)
-  @capture(shortdef(ex), (name_(args__) = body_) |
-                         (name_(args__) where {T__} = body_)) || error("Need a function definition")
-  T == nothing && (T = [])
-  isexpr(name, :(::)) || (name = :(::typeof($name)))
-  insert!(args, 1+isexpr(args[1], :parameters) , name)
-  @q(Tracker._forward($(args...)) where $(T...) = $body) |> esc
-end
-
-function update!(x, Δ)
-  x.data .+= data(Δ)
-  tracker(x).grad .= 0
-  return x
-end
-
-include("idset.jl")
 include("back.jl")
 include("scalar.jl")
 include("array.jl")
 include("numeric.jl")
 
-"""
-    hook(f, x) -> x′
-
-Hook into gradient backpropagation. `x` is unmodified, but when backpropagating
-`f` will be applied to the incoming gradient. For example, `hook(-, x)` will reverse
-the sign of the gradient applied to `x`."""
-hook(f, x) = istracked(x) ? track(hook, f, x) : x
-@grad hook(f, x) = data(x), Δ -> (nothing, f(Δ))
-
-"""
-    checkpoint(f, args...)
-
-Behaves like `f(args...)`, but avoids storing the intermediate values needed for
-calculating gradients. Instead, `f(args...)` will be called again during the
-backward pass. This can be used to save memory in larger models.
-"""
-checkpoint(f, args...) = track(checkpoint, f, args...)
-
-@grad function checkpoint(f, args...)
-  data(f(args...)), function (Δ)
-    y, back = forward(f, args...)
-    (nothing, back(Δ)...)
-  end
-end
-
-nobacksies(f, x) = track(nobacksies, f, x)
-nobacksies(f, xs::Tuple) = map(x -> nobacksies(f, x), xs)
-@grad nobacksies(f, x) = data(x), Δ -> error("Nested AD not defined for $f")
-
 param(x::Number) = TrackedReal(float(x))
 param(xs::AbstractArray) = TrackedArray(float.(xs))
 
-@grad identity(x) = data(x), Δ -> (Δ,)
-param(x::TrackedReal) = track(identity, x)
-param(x::TrackedArray) = track(identity, x)
-
 import NNlib.cudata
 import Adapt.adapt
 
 cudata(x::TrackedArray) = data(x)
 adapt(T, xs::TrackedArray) = param(adapt(T, data(xs)))
 
-end
+end
\ No newline at end of file
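
To make the new interface concrete, the following is a minimal, self-contained Julia sketch of the pattern this patch introduces: an optimiser is a small mutable struct holding its hyperparameters plus a dictionary of per-parameter state keyed by object identity, and `update!(opt, x, Δ)` only rescales the gradient in place; the caller then applies and zeroes it, as `update!(opt, xs)` does in train.jl. The names `SketchDescent`, `SketchMomentum` and `apply!` are invented for illustration, `IdDict`/`get!` stand in for the Julia 0.6-era `ObjectIdDict`/`@get!` used in the diff, and plain arrays stand in for tracked parameters; this is a sketch of the idea, not the actual Flux code.

mutable struct SketchDescent
  eta::Float64
end

# Like `update!(o::Descent, x, Δ)` above: rescale the gradient in place.
update!(o::SketchDescent, x, Δ) = (Δ .*= o.eta)

mutable struct SketchMomentum
  eta::Float64
  rho::Float64
  velocity::IdDict{Any,Any}
end

SketchMomentum(η, ρ = 0.9) = SketchMomentum(η, ρ, IdDict())

function update!(o::SketchMomentum, x, Δ)
  η, ρ = o.eta, o.rho
  v = get!(o.velocity, x, zero(x))::typeof(x)  # per-parameter state, keyed by identity
  @. v = ρ * v - η * Δ
  @. Δ = -v
end

# Mirrors `update!(opt, xs)` from train.jl: apply the rescaled gradient, then reset it.
function apply!(opt, xs, Δs)
  for (x, Δ) in zip(xs, Δs)
    update!(opt, x, Δ)
    x .-= Δ
    Δ .= 0
  end
end

# Usage: one optimisation step on a single parameter array.
W  = randn(3, 3)
ΔW = ones(3, 3)              # stand-in gradient
opt = SketchMomentum(0.1)
apply!(opt, [W], [ΔW])

Compared with the closures-over-`Param` approach in the deleted interface.jl, the state dictionary lets one optimiser instance be shared across many parameters and composed with decay later (see the `# TODO: decay` note), at the cost of keeping references to the parameters alive, which the `# TODO: should use weak refs` comment acknowledges.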
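The Tracker change points the same way: the tracked value now lives on the `Tracked` node itself (the new `data::T` field), so `data(x)` and `grad(x)` become plain field reads and `track(f, xs...)` simply wraps a recorded `Call` around the result of running `f`. Below is a rough standalone sketch of that layout with hypothetical `MiniCall`/`MiniTracked`/`mtrack` names and no backpropagation, assuming nothing beyond base Julia; it illustrates the data layout only, not the real Tracker types.

struct MiniCall{F,As<:Tuple}
  func::F
  args::As
end

MiniCall(f, args...) = MiniCall{typeof(f),typeof(args)}(f, args)

mutable struct MiniTracked{T}
  f::MiniCall
  isleaf::Bool
  data::T
  grad::T
end

# A leaf holds a parameter directly; a non-leaf records the call that produced its value.
leaf(x) = MiniTracked(MiniCall(nothing), true, x, zero(x))

function mtrack(f, xs...)
  y = f(xs...)                                 # run the forward computation
  MiniTracked(MiniCall(f, xs...), false, y, zero(y))
end

data(t::MiniTracked) = t.data
grad(t::MiniTracked) = t.grad

# Usage: a leaf parameter, and a node recording `sum` applied to its value.
p = leaf([1.0, 2.0, 3.0])
y = mtrack(sum, data(p))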