initial sketch

Mike J Innes 2018-05-31 20:29:59 +01:00 committed by Dhairya Gandhi
parent 53be49b102
commit a2d2d068aa
6 changed files with 148 additions and 307 deletions

View File

@@ -19,8 +19,8 @@ export Tracker, TrackedArray, TrackedVector, TrackedMatrix, param
 include("optimise/Optimise.jl")
 using .Optimise
 using .Optimise: @epochs
-export SGD, ADAM, ADAMW, AdaMax, Momentum, Nesterov,
-  RMSProp, ADAGrad, ADADelta, AMSGrad, NADAM
+export SGD, Descent, ADAM, AdaMax, Momentum, Nesterov,
+  RMSProp, ADAGrad, ADADelta, AMSGrad

 include("utils.jl")
 include("onehot.jl")

View File

@@ -1,23 +1,9 @@
 module Optimise

 export train!,
-  SGD, ADAM, ADAMW, AdaMax, Momentum, Nesterov,
-  RMSProp, ADAGrad, ADADelta, AMSGrad, NADAM, stop, StopException
-
-struct Param{T}
-  x::T
-  Δ::T
-end
-
-Param(x::AbstractArray) = Param(x, zero(x))
+  SGD, Descent, ADAM, AdaMax, Momentum, Nesterov, RMSProp, ADAGrad, ADADelta, AMSGrad

 include("optimisers.jl")
-include("interface.jl")
 include("train.jl")
-
-using Flux.Tracker: TrackedArray
-
-Param(x::TrackedArray) = Param(x.data, x.grad)
-# Base.convert(::Type{Param}, x::TrackedArray) = Param(x.data, x.grad)

 end

View File

@@ -1,110 +0,0 @@
call(f, xs...) = f(xs...)
# note for optimisers: set to zero
# p.Δ at the end of the weights update
function optimiser(ps, fs...)
  ps = [Param(p) for p in ps]
  fs = map(ps) do p
    os = map(f -> f(p), fs)
    () -> foreach(call, os)
  end
  () -> foreach(call, fs)
end
"""
SGD(params, η = 0.1; decay = 0)
Classic gradient descent optimiser with learning rate `η`.
For each parameter `p` and its gradient `δp`, this runs `p -= η*δp`.
Supports inverse decaying learning rate if the `decay` argument is provided.
"""
SGD(ps, η = 0.1; decay = 0) =
optimiser(ps, p -> invdecay(p, decay), p -> descent(p,η))
"""
Momentum(params, η = 0.01; ρ = 0.9, decay = 0)
SGD with learning rate `η`, momentum `ρ` and optional learning rate inverse decay.
"""
Momentum(ps, η = 0.01; ρ = 0.9, decay = 0) =
optimiser(ps, p->invdecay(p,decay), p->momentum(p, ρ, η), p->descent(p,1))
"""
Nesterov(params, η = 0.01; ρ = 0.9, decay = 0)
SGD with learning rate `η`, Nesterov momentum `ρ` and optional learning rate inverse decay.
"""
Nesterov(ps, η = 0.01; ρ = 0.9, decay = 0) =
optimiser(ps, p->invdecay(p,decay), p->nesterov(p, ρ, η), p->descent(p,1))
"""
RMSProp(params, η = 0.001; ρ = 0.9, ϵ = 1e-8, decay = 0)
[RMSProp](http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf)
optimiser. Parameters other than learning rate don't need tuning. Often a good
choice for recurrent networks.
"""
RMSProp(ps, η = 0.001; ρ = 0.9, ϵ = 1e-8, decay = 0) =
optimiser(ps, p->rmsprop(p; η=η, ρ=ρ, ϵ=ϵ), p->invdecay(p,decay), p->descent(p,1))
"""
ADAM(params, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0)
[ADAM](https://arxiv.org/abs/1412.6980v8) optimiser.
"""
ADAM(ps, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0) =
optimiser(ps, p->adam(p; η=η, β1=β1, β2=β2, ϵ=ϵ), p->invdecay(p,decay), p->descent(p,1))
"""
ADAMW((params, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0)
[ADAMW](https://arxiv.org/abs/1711.05101) fixing weight decay regularization in Adam.
"""
ADAMW(ps, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0) =
optimiser(ps, p->adam(p; η=η, β1=β1, β2=β2, ϵ=ϵ), p->descentweightdecay(p,1,decay))
"""
AdaMax(params, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0)
[AdaMax](https://arxiv.org/abs/1412.6980v9) optimiser. Variant of ADAM based on
the ∞-norm.
"""
AdaMax(ps, η = 0.002; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0) =
optimiser(ps, p->adamax(p; η=η, β1=β1, β2=β2, ϵ=ϵ), p->invdecay(p,decay), p->descent(p,1))
"""
ADAGrad(params, η = 0.01; ϵ = 1e-8, decay = 0)
[ADAGrad](http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf) optimiser.
Parameters don't need tuning.
"""
ADAGrad(ps, η = 0.01; ϵ = 1e-8, decay = 0) =
optimiser(ps, p->adagrad(p; η=η, ϵ=ϵ), p->invdecay(p,decay), p->descent(p,1))
"""
ADADelta(params; ρ = 0.9, ϵ = 1e-8, decay = 0)
[ADADelta](http://arxiv.org/abs/1212.5701) optimiser. Parameters don't need
tuning.
"""
ADADelta(ps; ρ = 0.9, ϵ = 1e-8, decay = 0) =
optimiser(ps, p->adadelta(p; ρ=ρ, ϵ=ϵ), p->descent(p,1))
"""
AMSGrad(params; η = 0.001, β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0)
[AMSGrad](https://openreview.net/forum?id=ryQu7f-RZ) optimiser. Parameters don't need
tuning.
"""
AMSGrad(ps, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0) =
optimiser(ps, p -> amsgrad(p; η = η, β1 = β1, β2 = β2, ϵ = ϵ), p -> invdecay(p, decay), p -> descent(p, 1))
"""
NADAM(params, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0)
[NADAM](https://openreview.net/pdf?id=OM0jvwB8jIp57ZJjtNEZ) optimiser. Parameters other
than learning rate don't need tuning.
"""
NADAM(ps, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0) =
optimiser(ps, p->nadam(p; η=η, β1=β1, β2=β2, ϵ=ϵ), p->invdecay(p,decay), p->descent(p,1))
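
For context on what is being deleted here: interface.jl built every optimiser by wrapping each parameter in a Param (weights plus gradient buffer) and composing per-parameter update closures. Below is a minimal, self-contained sketch of that pattern, with simplified descent and invdecay steps standing in for the originals; it is illustrative only, not the removed Flux code verbatim.

# Sketch of the removed closure-based design (simplified).
struct Param{T}
  x::T   # weights
  Δ::T   # gradient buffer
end
Param(x::AbstractArray) = Param(x, zero(x))

call(f, xs...) = f(xs...)

# Each step returns a zero-argument closure that transforms or applies p.Δ in place.
function descent(p::Param, η)
  () -> begin
    p.x .-= η .* p.Δ
    p.Δ .= 0
  end
end

function invdecay(p::Param, γ)
  γ == 0 && return () -> nothing
  n = 0
  () -> begin
    p.Δ .*= 1 / (1 + γ * n)
    n += 1
  end
end

# `optimiser` chains the steps for every parameter and returns a single callable.
function optimiser(ps, fs...)
  ps = [Param(p) for p in ps]
  fs = map(ps) do p
    os = map(f -> f(p), fs)
    () -> foreach(call, os)
  end
  () -> foreach(call, fs)
end

w = rand(3)
opt = optimiser([w], p -> invdecay(p, 0.1), p -> descent(p, 0.01))
opt()   # one update step; p.Δ is still zero here, so w is unchanged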

View File

@@ -1,130 +1,139 @@
-function descent(p::Param, η::Real)
-  function ()
-    @. p.x -= η * p.Δ
-    @. p.Δ = 0
-  end
-end
-
-# Ref: https://arxiv.org/abs/1711.05101.pdf
-function descentweightdecay(p::Param, η::Real, γ::Real)
-  function ()
-    @. p.x = p.x - η * (p.Δ + γ * p.x)
-    @. p.Δ = 0
-  end
-end
-
-function momentum(p::Param, ρ, η)
-  v = zero(p.x)
-  function ()
-    @. v = ρ * v - η * p.Δ
-    @. p.Δ = -v
-  end
-end
-
-# Ref. https://arxiv.org/pdf/1212.0901.pdf
-function nesterov(p::Param, ρ, η)
-  v = zero(p.x)
-  function ()
-    d = @. ρ^2 * v - (1+ρ) * η * p.Δ
-    @. v = ρ*v - η*p.Δ
-    @. p.Δ = -d
-  end
-end
-
-function rmsprop(p::Param; η::Real = 0.001, ρ::Real = 0.9, ϵ::Real = 1e-8)
-  acc = zero(p.x)
-  function ()
-    @. acc = ρ * acc + (1 - ρ) * p.Δ^2
-    @. p.Δ *= η / √(acc + ϵ)
-  end
-end
-
-function adagrad(p::Param; η::Real = 0.01, ϵ::Real = 1e-8)
-  acc = zero(p.x) .+ ϵ
-  function ()
-    @. acc += p.Δ^2
-    @. p.Δ *= η / √(acc + ϵ)
-  end
-end
-
-function adadelta(p::Param; ρ::Real = 0.9, ϵ::Real = 1e-8)
-  acc = zero(p.x)
-  Δacc = zero(p.x)
-  function ()
-    @. acc = ρ * acc + (1 - ρ) * p.Δ^2
-    @. p.Δ *= √(Δacc + ϵ) / √(acc + ϵ)
-    @. Δacc = ρ * Δacc + (1 - ρ) * p.Δ^2
-  end
-end
-
-function adam(p::Param; η::Real = 0.001, β1::Real = 0.9, β2::Real = 0.999, ϵ::Real = 1e-8)
-  mt = zero(p.x)
-  vt = zero(p.x)
-  β1p, β2p = β1, β2
-  function ()
-    @. mt = β1 * mt + (1 - β1) * p.Δ
-    @. vt = β2 * vt + (1 - β2) * p.Δ^2
-    @. p.Δ = mt / (1 - β1p) / √(vt / (1 - β2p) + ϵ) * η
-    β1p *= β1
-    β2p *= β2
-  end
-end
-
-function adamax(p::Param; η::Real = 0.002, β1::Real = 0.9, β2::Real = 0.999, ϵ::Real = 1e-8)
-  mt = zero(p.x)
-  ut = zero(p.x)
-  β1p = β1
-  function ()
-    @. mt = β1 * mt + (1 - β1) * p.Δ
-    @. ut = max(β2 * ut, abs(p.Δ))
-    @. p.Δ = (η/(1 - β1p)) * mt/(ut + ϵ)
-    β1p *= β1
-  end
-end
-
-function amsgrad(p::Param; η::Real = 0.001, β1::Real = 0.9, β2::Real = 0.999, ϵ::Real = 1e-8)
-  mt = zero(p.x)
-  vt = zero(p.x) .+ ϵ
-  v̂t = zero(p.x) .+ ϵ
-  function ()
-    @. mt = β1 * mt + (1 - β1) * p.Δ
-    @. vt = β2 * vt + (1 - β2) * p.Δ ^ 2
-    @. v̂t = max.(v̂t, vt)
-    @. p.Δ = η * mt / √v̂t
-  end
-end
-
-function nadam(p::Param; η::Real = 0.001, β1::Real = 0.9, β2::Real = 0.999, ϵ::Real = 1e-8)
-  mt = zero(p.x)
-  vt = zero(p.x)
-  β1p, β2p = β1, β2
-  function ()
-    @. mt = β1 * mt + (1 - β1) * p.Δ
-    @. vt = β2 * vt + (1 - β2) * p.Δ^2
-    @. p.Δ = (β1 * mt / (1 - β1 * β1p) + (1 - β1) * p.Δ / (1 - β1p)) / √(vt * β2 / (1 - β2p) + ϵ) * η
-    β1p *= β1
-    β2p *= β2
-  end
-end
-
-clip(p::Param, thresh::Real) = () -> clamp!(p.Δ, -thresh, thresh)
-
-function expdecay(p::Param, γ::Real)
-  if γ != 0
-    return () -> p.Δ .+= γ .* p.x
-  else
-    return () -> nothing
-  end
-end
-
-function invdecay(p::Param, γ::Real)
-  if γ != 0
-    n = 0
-    return () -> begin
-      p.Δ .*= 1 / (1 + γ * n)
-      n += 1
-    end
-  else
-    return () -> nothing
-  end
-end
+using Flux
+using Base: @get!
+
+const ϵ = 1e-8
+
+# TODO: should use weak refs
+
+"""
+    Descent(η)
+
+Classic gradient descent optimiser with learning rate `η`.
+For each parameter `p` and its gradient `δp`, this runs `p -= η*δp`.
+"""
+mutable struct Descent
+  eta::Float64
+end
+
+function update!(o::Descent, x, Δ)
+  Δ .*= o.eta
+end
+
+"""
+    Momentum(params, η = 0.01; ρ = 0.9, decay = 0)
+
+Gradient descent with learning rate `η` and momentum `ρ`.
+"""
+mutable struct Momentum
+  eta::Float64
+  rho::Float64
+  velocity::ObjectIdDict
+end
+
+Momentum(η, ρ = 0.9) = Momentum(η, ρ, ObjectIdDict())
+
+function update!(o::Momentum, x, Δ)
+  η, ρ = o.eta, o.rho
+  v = @get!(o.velocity, x, zero(x))::typeof(x)
+  @. v = ρ * v - η * Δ
+  @. Δ = -v
+end
+
+"""
+    Nesterov(eta, ρ = 0.9)
+
+Gradient descent with learning rate `η` and Nesterov momentum `ρ`.
+"""
+mutable struct Nesterov
+  eta::Float64
+  rho::Float64
+  velocity::ObjectIdDict
+end
+
+Nesterov(η, ρ = 0.9) = Nesterov(η, ρ, ObjectIdDict())
+
+function update!(o::Nesterov, x, Δ)
+  η, ρ = o.eta, o.rho
+  v = @get!(o.velocity, x, zero(x))::typeof(x)
+  d = @. ρ^2 * v - (1+ρ) * η * Δ
+  @. v = ρ*v - η*Δ
+  @. Δ = -d
+end
+
+"""
+    RMSProp(η = 0.001, ρ = 0.9)
+
+[RMSProp](http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf)
+optimiser. Parameters other than learning rate don't need tuning. Often a good
+choice for recurrent networks.
+"""
+mutable struct RMSProp
+  eta::Float64
+  rho::Float64
+  acc::ObjectIdDict
+end
+
+RMSProp(η = 0.001, ρ = 0.9) = RMSProp(η, ρ, ObjectIdDict())
+
+function update!(o::RMSProp, x, Δ)
+  η, ρ = o.eta, o.rho
+  acc = @get!(o.acc, x, zero(x))::typeof(x)
+  @. acc = ρ * acc + (1 - ρ) * Δ^2
+  @. Δ *= η / √(acc + ϵ)
+end
+
+"""
+    ADAM(params, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0)
+
+[ADAM](https://arxiv.org/abs/1412.6980v8) optimiser.
+"""
+mutable struct ADAM
+  eta::Float64
+  beta::Tuple{Float64,Float64}
+  state::ObjectIdDict
+end
+
+ADAM(η = 0.001, β = (0.9, 0.999)) = ADAM(η, β, ObjectIdDict())
+
+function update!(o::ADAM, x, Δ)
+  η, β = o.eta, o.beta
+  mt, vt, βp = @get!(o.state, x, (zero(x), zero(x), β))
+  @. mt = β[1] * mt + (1 - β[1]) * Δ
+  @. vt = β[2] * vt + (1 - β[2]) * Δ^2
+  @. Δ = mt / (1 - βp[1]) / (√(vt / (1 - βp[2])) + ϵ) * η
+  o.state[x] = (mt, vt, βp .* β)
+end
+
+# """
+#     AdaMax(params, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0)
+#
+# [AdaMax](https://arxiv.org/abs/1412.6980v9) optimiser. Variant of ADAM based on
+# the ∞-norm.
+# """
+
+# """
+#     ADAGrad(params, η = 0.01; ϵ = 1e-8, decay = 0)
+#
+# [ADAGrad](http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf) optimiser.
+# Parameters don't need tuning.
+# """
+
+# """
+#     ADADelta(params; ρ = 0.9, ϵ = 1e-8, decay = 0)
+#
+# [ADADelta](http://arxiv.org/abs/1212.5701) optimiser. Parameters don't need
+# tuning.
+# """
+
+# """
+#     AMSGrad(params; η = 0.001, β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0)
+#
+# [AMSGrad](https://openreview.net/forum?id=ryQu7f-RZ) optimiser. Parameters don't need
+# tuning.
+# """
+
+# struct Optimiser
+#   os::Vector{Any}
+# end
+
+# TODO: decay
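
The sketch above replaces per-parameter closures with plain optimiser structs whose update!(opt, x, Δ) rewrites the gradient Δ in place; the caller then applies it to x. Below is a minimal, self-contained illustration of that pattern on current Julia (IdDict and get! stand in for the Julia 0.6 ObjectIdDict and @get! used in the diff); the names mirror the diff, but this is not Flux's exported API.

# Struct-based update! pattern, trimmed to Descent and Momentum.
mutable struct Descent
  eta::Float64
end

# update! only rescales the gradient in place; the training loop applies it.
update!(o::Descent, x, Δ) = Δ .*= o.eta

mutable struct Momentum
  eta::Float64
  rho::Float64
  velocity::IdDict   # per-parameter state, keyed by the parameter array itself
end
Momentum(η, ρ = 0.9) = Momentum(η, ρ, IdDict())

function update!(o::Momentum, x, Δ)
  η, ρ = o.eta, o.rho
  v = get!(() -> zero(x), o.velocity, x)::typeof(x)
  @. v = ρ * v - η * Δ
  @. Δ = -v
end

# Usage: the caller owns the parameter update.
w  = rand(5)
Δw = 0.1 .* randn(5)     # pretend gradient
opt = Momentum(0.01)
update!(opt, w, Δw)      # rescale Δw using per-parameter velocity
w .-= Δw                 # apply it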

View File

@@ -1,7 +1,16 @@
 using Juno
-using Flux.Tracker: back!
+using Flux.Tracker: data, grad, back!
+import Base.depwarn
+
+function update!(opt, xs)
+  for x in xs
+    x, Δ = data(x), grad(x)
+    update!(opt, x, Δ)
+    x .-= Δ
+    Δ .= 0
+  end
+end

+# Callback niceties
 runall(f) = f
 runall(fs::AbstractVector) = () -> foreach(call, fs)
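
The update!(opt, xs) added here is the glue between the tracker and the new optimisers: for each tracked parameter it takes data(x) and grad(x), lets the optimiser rescale the gradient in place, subtracts it, and zeroes it. A rough stand-in sketch of that control flow follows, using a hypothetical FakeParam in place of Flux's tracked arrays (so data and grad here are plain field accessors, not Tracker's).

# Rough sketch of the new train.jl loop, with FakeParam as a stand-in.
struct FakeParam
  data::Vector{Float64}
  grad::Vector{Float64}
end
data(p::FakeParam) = p.data
grad(p::FakeParam) = p.grad

mutable struct Descent
  eta::Float64
end
update!(o::Descent, x, Δ) = Δ .*= o.eta

# Same shape as the new update!(opt, xs): scale the gradient, apply it, reset it.
function update!(opt, xs)
  for p in xs
    x, Δ = data(p), grad(p)
    update!(opt, x, Δ)
    x .-= Δ
    Δ .= 0
  end
end

p = FakeParam(rand(3), ones(3))
update!(Descent(0.1), [p])
@assert all(iszero, p.grad)   # gradients are cleared after the step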

View File

@@ -1,27 +1,23 @@
 module Tracker

-using MacroTools
-using MacroTools: @q, @forward
-
 import Base: ==

-export TrackedArray, TrackedVector, TrackedMatrix, Params, param, back!
+export TrackedArray, TrackedVector, TrackedMatrix, param, back!

 tracker(x) = nothing

 istracked(x) = tracker(x) ≠ nothing
 isleaf(x) = !istracked(x) || isleaf(tracker(x))
-data(x) = istracked(x) ? data(tracker(x)) : x
 grad(x) = grad(tracker(x))
 grad(::Nothing) = nothing
+data(x) = x

 struct Call{F,As<:Tuple}
   func::F
   args::As
 end

-Call(f::F, args::T) where {F,T} = Call{F,T}(f, args)
-Call() = Call(nothing, ())
+Call(f, args...) = Call{typeof(f),typeof(args)}(f, args)

 # When deserialising, the object_id changes
 a::Call == b::Call = a.func == b.func && a.args == b.args
@@ -32,82 +28,33 @@ mutable struct Tracked{T}
   ref::UInt32
   f::Call
   isleaf::Bool
+  data::T
   grad::T
-  Tracked{T}(f::Call) where T = new(0, f, false)
-  Tracked{T}(f::Call, grad::T) where T = new(0, f, false, grad)
-  Tracked{T}(f::Call{Nothing}, grad::T) where T = new(0, f, true, grad)
+  Tracked{T}(f::Call, data::T) where T = new(0, f, false, data)
+  Tracked{T}(f::Call, data::T, grad::T) where T = new(0, f, false, data, grad)
+  Tracked{T}(f::Call{Void}, data::T, grad::T) where T = new(0, f, true, data, grad)
 end

+Tracked(f::Call, x) = Tracked{typeof(x)}(f, x)
+Tracked(f::Call, x, Δ) = Tracked{typeof(x)}(f, x, Δ)
+
+track(f::Call, x) = Tracked(f, x)
+track(f::Call) = track(f, f())
+track(f, xs...) = track(Call(f, xs...))
+
 istracked(x::Tracked) = true
-isleaf(x::Tracked) = x.f == Call()
+isleaf(x::Tracked) = x.f == Call(nothing)
+data(x::Tracked) = x.data
 grad(x::Tracked) = x.grad

-track(f::Call, x) = Tracked{typeof(x)}(f)
-
-function _forward end
-
-function track(f::F, xs...; kw...) where F
-  y, back = _forward(f, xs...; kw...)
-  track(Call(back, tracker.(xs)), y)
-end
-
-macro grad(ex)
-  @capture(shortdef(ex), (name_(args__) = body_) |
-    (name_(args__) where {T__} = body_)) || error("Need a function definition")
-  T == nothing && (T = [])
-  isexpr(name, :(::)) || (name = :(::typeof($name)))
-  insert!(args, 1+isexpr(args[1], :parameters) , name)
-  @q(Tracker._forward($(args...)) where $(T...) = $body) |> esc
-end
-
-function update!(x, Δ)
-  x.data .+= data(Δ)
-  tracker(x).grad .= 0
-  return x
-end
-
-include("idset.jl")
 include("back.jl")
 include("scalar.jl")
 include("array.jl")
 include("numeric.jl")

-"""
-    hook(f, x) -> x
-
-Hook into gradient backpropagation. `x` is unmodified, but when backpropagating
-`f` will be applied to the incoming gradient. For example, `hook(-, x)` will reverse
-the sign of the gradient applied to `x`."""
-hook(f, x) = istracked(x) ? track(hook, f, x) : x
-@grad hook(f, x) = data(x), Δ -> (nothing, f(Δ))
-
-"""
-    checkpoint(f, args...)
-
-Behaves like `f(args...)`, but avoids storing the intermediate values needed for
-calculating gradients. Instead, `f(args...)` will be called again during the
-backward pass. This can be used to save memory in larger models.
-"""
-checkpoint(f, args...) = track(checkpoint, f, args...)
-
-@grad function checkpoint(f, args...)
-  data(f(args...)), function (Δ)
-    y, back = forward(f, args...)
-    (nothing, back(Δ)...)
-  end
-end
-
-nobacksies(f, x) = track(nobacksies, f, x)
-nobacksies(f, xs::Tuple) = map(x -> nobacksies(f, x), xs)
-@grad nobacksies(f, x) = data(x), Δ -> error("Nested AD not defined for $f")
-
 param(x::Number) = TrackedReal(float(x))
 param(xs::AbstractArray) = TrackedArray(float.(xs))

-@grad identity(x) = data(x), Δ -> (Δ,)
-param(x::TrackedReal) = track(identity, x)
-param(x::TrackedArray) = track(identity, x)
-
 import NNlib.cudata
 import Adapt.adapt
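
In this sketch the Tracked{T} node carries its own data alongside grad, so data(x) and grad(x) read directly off the node rather than delegating through a separate wrapper. A stripped-down illustration of just that layout follows; the Call recording and back! machinery of the real module are omitted, so this is not the Tracker implementation.

# Stripped-down illustration: the node owns both the value and its gradient buffer.
struct Call{F,As<:Tuple}
  func::F
  args::As
end
Call(f, args...) = Call{typeof(f),typeof(args)}(f, args)

mutable struct Tracked{T}
  ref::UInt32
  f::Call
  isleaf::Bool
  data::T
  grad::T
end

# Leaf constructor: no recorded call, gradient buffer starts at zero.
Tracked(f::Call, x) = Tracked{typeof(x)}(0, f, false, x, zero(x))

data(t::Tracked) = t.data
grad(t::Tracked) = t.grad
isleaf(t::Tracked) = t.f == Call(nothing)

t = Tracked(Call(nothing), rand(2, 2))
isleaf(t)        # true: no recorded call produced this node
size(data(t))    # (2, 2)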