Merge branch 'master' into ed/diagm-pair

Mike J Innes 2018-11-05 11:51:29 +00:00
commit d0e4fbb1e0
17 changed files with 555 additions and 309 deletions

View File

@@ -4,7 +4,6 @@ os:
 - linux
 # - osx
 julia:
-- 0.7
 - 1.0
 - nightly
 # uncomment the following lines to override the default test script

View File

@ -1,4 +1,4 @@
julia 0.7 julia 1.0
Juno Juno
MacroTools 0.3.3 MacroTools 0.3.3
NNlib NNlib

View File

@@ -19,8 +19,9 @@ export Tracker, TrackedArray, TrackedVector, TrackedMatrix, param
 include("optimise/Optimise.jl")
 using .Optimise
 using .Optimise: @epochs
-export SGD, ADAM, ADAMW, AdaMax, Momentum, Nesterov,
-  RMSProp, ADAGrad, ADADelta, AMSGrad, NADAM
+export SGD, Descent, ADAM, Momentum, Nesterov, RMSProp,
+  ADAGrad, AdaMax, ADADelta, AMSGrad, NADAM,
+  ADAMW, InvDecay, ExpDecay, WeightDecay

 include("utils.jl")
 include("onehot.jl")

View File

@@ -1,23 +1,12 @@
 module Optimise

 export train!,
-  SGD, ADAM, ADAMW, AdaMax, Momentum, Nesterov,
-  RMSProp, ADAGrad, ADADelta, AMSGrad, NADAM, stop, StopException
-
-struct Param{T}
-  x::T
-  Δ::T
-end
-
-Param(x::AbstractArray) = Param(x, zero(x))
+  SGD, Descent, ADAM, Momentum, Nesterov, RMSProp,
+  ADAGrad, AdaMax, ADADelta, AMSGrad, NADAM, ADAMW,
+  InvDecay, ExpDecay, WeightDecay, stop, Optimiser

 include("optimisers.jl")
-include("interface.jl")
 include("train.jl")
-
-using Flux.Tracker: TrackedArray
-
-Param(x::TrackedArray) = Param(x.data, x.grad)
-# Base.convert(::Type{Param}, x::TrackedArray) = Param(x.data, x.grad)
+include("deprecations.jl")

 end

View File

@@ -0,0 +1,125 @@
using Base: depwarn
check_decay(opt, decay) = decay == 0 ? opt : Optimiser(opt, InvDecay(decay))
# legacy update rule
updaterule(opt, ps) = () -> update!(opt, ps)
function SGD(params::AbstractArray, η = 0.1; decay = 0.)
depwarn("SGD(params) is deprecated; use Descent(η::Float64) instead", :SGD)
ps = params
opt = Descent(η)
opt = check_decay(opt, decay)
updaterule(opt, ps)
end
function Momentum(params::AbstractArray, η = 0.01; ρ = 0.9, decay = 0.)
depwarn("Momentum(params) is deprecated; use Momentum(η::Float64) instead", :Momentum)
ps = params
opt = Momentum(η, ρ)
opt = check_decay(opt, decay)
updaterule(opt, ps)
end
function Nesterov(params::AbstractArray, η = 0.001; ρ = 0.9, decay = 0.)
depwarn("Nesterov(params) is deprecated; use Nesterov(η::Float64) instead", :Nesterov)
ps = params
opt = Nesterov(η, ρ)
opt = check_decay(opt, decay)
updaterule(opt, ps)
end
function RMSProp(params::AbstractArray, η = 0.001; ρ = 0.9, decay = 0.)
depwarn("RMSProp(params) is deprecated; use RMSProp(η::Float64) instead", :RMSProp)
ps = params
opt = RMSProp(η, ρ)
opt = check_decay(opt, decay)
updaterule(opt, ps)
end
function ADAM(params::AbstractArray, η = 0.001; β1 = 0.9, β2 = 0.999, decay = 0.)
depwarn("ADAM(params) is deprecated; use ADAM(η::Float64) instead", :ADAM)
ps = params
β = (β1, β2)
opt = ADAM(η, β)
opt = check_decay(opt, decay)
updaterule(opt, ps)
end
function ADAGrad(params::AbstractArray, η::Float64 = 0.1; decay = 0.)
depwarn("ADAGrad(params) is deprecated; use ADAGrad(η::Float64) instead", :ADAGrad)
ps = params
opt = ADAGrad(η)
opt = check_decay(opt, decay)
updaterule(opt, ps)
end
function ADADelta(params::AbstractArray, ρ::Float64 = 0.9; decay = 0.)
depwarn("ADADelta(params) is deprecated; use ADADelta(η::Float64) instead", :ADADelta)
ps = params
opt = ADADelta(ρ)
opt = check_decay(opt, decay)
updaterule(opt, ps)
end
function AdaMax(params::AbstractArray, η = 0.001; β1 = 0.9, β2 = 0.999, decay = 0.)
depwarn("AdaMax(params) is deprecated; use AdaMax(η::Float64) instead", :AdaMax)
ps = params
β = (β1, β2)
opt = AdaMax(η, β)
opt = check_decay(opt, decay)
updaterule(opt, ps)
end
function AMSGrad(params::AbstractArray, η = 0.001; β1 = 0.9, β2 = 0.999, decay = 0.)
depwarn("AMSGrad(params) is deprecated; use AMSGrad(η::Float64) instead", :AMSGrad)
ps = params
β = (β1, β2)
opt = AMSGrad(η, β)
opt = check_decay(opt, decay)
updaterule(opt, ps)
end
function NADAM(params::AbstractArray, η = 0.001; β1 = 0.9, β2 = 0.999, decay = 0.)
depwarn("NADAM(params) is deprecated; use NADAM(η::Float64) instead", :NADAM)
ps = params
β = (β1, β2)
opt = NADAM(η, β)
opt = check_decay(opt, decay)
updaterule(opt, ps)
end
function ADAMW(params::AbstractArray, η = 0.001; β1 = 0.9, β2 = 0.999, decay = 0.)
depwarn("ADAMW(params) is deprecated; use ADAMW(η::Float64) instead", :ADAMW)
ps = params
β = (β1, β2)
opt = ADAMW(η, β)
opt = check_decay(opt, decay)
decay != 0 && (opt = Optimiser(opt, WeightDecay(decay)))
updaterule(opt, ps)
end
# Old training loop
struct OldOptimiser
func
end
update!(opt::OldOptimiser, ps) = opt.func()
# Train function
function train!(loss, data, opt; cb = () -> ())
depwarn("train!(loss, data, opt) is deprecated; use train!(loss, params, data, opt) instead", :train!)
train!(loss, (), data, OldOptimiser(opt); cb = cb)
end
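
Note: these shims translate the removed parameter-carrying constructors into the new stateless optimisers (plus an `InvDecay` wrapper for `decay`). A rough migration sketch, not part of the diff; `W`, `b`, `loss` and `data` below are placeholder definitions:

```julia
using Flux

W = param(randn(2, 3))
b = param(zeros(2))
loss(x, y) = Flux.mse(W*x .+ b, y)
data = [(rand(3), rand(2))]

# Old, now-deprecated style: the optimiser is built from the parameter list and
# applied by calling the returned thunk; the shims above keep this working.
opt_old = SGD([W, b], 0.1)        # emits a depwarn and wraps Descent(0.1)

# New style: the optimiser holds only hyperparameters and per-parameter state;
# the parameters are passed to `train!` explicitly.
opt = Descent(0.1)
Flux.train!(loss, (W, b), data, opt)
```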

View File

@@ -1,110 +0,0 @@
call(f, xs...) = f(xs...)
# note for optimisers: set to zero
# p.Δ at the end of the weights update
function optimiser(ps, fs...)
ps = [Param(p) for p in ps]
fs = map(ps) do p
os = map(f -> f(p), fs)
() -> foreach(call, os)
end
() -> foreach(call, fs)
end
"""
SGD(params, η = 0.1; decay = 0)
Classic gradient descent optimiser with learning rate `η`.
For each parameter `p` and its gradient `δp`, this runs `p -= η*δp`.
Supports inverse decaying learning rate if the `decay` argument is provided.
"""
SGD(ps, η = 0.1; decay = 0) =
optimiser(ps, p -> invdecay(p, decay), p -> descent(p,η))
"""
Momentum(params, η = 0.01; ρ = 0.9, decay = 0)
SGD with learning rate `η`, momentum `ρ` and optional learning rate inverse decay.
"""
Momentum(ps, η = 0.01; ρ = 0.9, decay = 0) =
optimiser(ps, p->invdecay(p,decay), p->momentum(p, ρ, η), p->descent(p,1))
"""
Nesterov(params, η = 0.01; ρ = 0.9, decay = 0)
SGD with learning rate `η`, Nesterov momentum `ρ` and optional learning rate inverse decay.
"""
Nesterov(ps, η = 0.01; ρ = 0.9, decay = 0) =
optimiser(ps, p->invdecay(p,decay), p->nesterov(p, ρ, η), p->descent(p,1))
"""
RMSProp(params, η = 0.001; ρ = 0.9, ϵ = 1e-8, decay = 0)
[RMSProp](http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf)
optimiser. Parameters other than learning rate don't need tuning. Often a good
choice for recurrent networks.
"""
RMSProp(ps, η = 0.001; ρ = 0.9, ϵ = 1e-8, decay = 0) =
optimiser(ps, p->rmsprop(p; η=η, ρ=ρ, ϵ=ϵ), p->invdecay(p,decay), p->descent(p,1))
"""
ADAM(params, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0)
[ADAM](https://arxiv.org/abs/1412.6980v8) optimiser.
"""
ADAM(ps, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0) =
optimiser(ps, p->adam(p; η=η, β1=β1, β2=β2, ϵ=ϵ), p->invdecay(p,decay), p->descent(p,1))
"""
ADAMW((params, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0)
[ADAMW](https://arxiv.org/abs/1711.05101) fixing weight decay regularization in Adam.
"""
ADAMW(ps, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0) =
optimiser(ps, p->adam(p; η=η, β1=β1, β2=β2, ϵ=ϵ), p->descentweightdecay(p,1,decay))
"""
AdaMax(params, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0)
[AdaMax](https://arxiv.org/abs/1412.6980v9) optimiser. Variant of ADAM based on
the ∞-norm.
"""
AdaMax(ps, η = 0.002; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0) =
optimiser(ps, p->adamax(p; η=η, β1=β1, β2=β2, ϵ=ϵ), p->invdecay(p,decay), p->descent(p,1))
"""
ADAGrad(params, η = 0.01; ϵ = 1e-8, decay = 0)
[ADAGrad](http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf) optimiser.
Parameters don't need tuning.
"""
ADAGrad(ps, η = 0.01; ϵ = 1e-8, decay = 0) =
optimiser(ps, p->adagrad(p; η=η, ϵ=ϵ), p->invdecay(p,decay), p->descent(p,1))
"""
ADADelta(params; ρ = 0.9, ϵ = 1e-8, decay = 0)
[ADADelta](http://arxiv.org/abs/1212.5701) optimiser. Parameters don't need
tuning.
"""
ADADelta(ps; ρ = 0.9, ϵ = 1e-8, decay = 0) =
optimiser(ps, p->adadelta(p; ρ=ρ, ϵ=ϵ), p->descent(p,1))
"""
AMSGrad(params; η = 0.001, β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0)
[AMSGrad](https://openreview.net/forum?id=ryQu7f-RZ) optimiser. Parameters don't need
tuning.
"""
AMSGrad(ps, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0) =
optimiser(ps, p -> amsgrad(p; η = η, β1 = β1, β2 = β2, ϵ = ϵ), p -> invdecay(p, decay), p -> descent(p, 1))
"""
NADAM(params, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0)
[NADAM](https://openreview.net/pdf?id=OM0jvwB8jIp57ZJjtNEZ) optimiser. Parameters other
than learning rate don't need tuning.
"""
NADAM(ps, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0) =
optimiser(ps, p->nadam(p; η=η, β1=β1, β2=β2, ϵ=ϵ), p->invdecay(p,decay), p->descent(p,1))
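
For context, the deleted file documented the closure-based API: each constructor captured the parameters and composed per-parameter update closures. A minimal sketch of that old usage (placeholder parameters and loss, not part of the diff):

```julia
using Flux, Flux.Tracker

W = param(randn(2, 3))
b = param(zeros(2))
loss(x, y) = Flux.mse(W*x .+ b, y)

opt = ADAM([W, b], 0.001)   # built one update closure per parameter
l = loss(rand(3), rand(2))
back!(l)                    # accumulate gradients
opt()                       # apply the update and zero the gradients
```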

View File

@@ -1,130 +1,304 @@
-function descent(p::Param, η::Real)
-  function ()
-    @. p.x -= η * p.Δ
-    @. p.Δ = 0
-  end
-end
-
-# Ref: https://arxiv.org/abs/1711.05101.pdf
-function descentweightdecay(p::Param, η::Real, γ::Real)
-  function ()
-    @. p.x = p.x - η * (p.Δ + γ * p.x)
-    @. p.Δ = 0
-  end
-end
-
-function momentum(p::Param, ρ, η)
-  v = zero(p.x)
-  function ()
-    @. v = ρ * v - η * p.Δ
-    @. p.Δ = -v
-  end
-end
-
-# Ref. https://arxiv.org/pdf/1212.0901.pdf
-function nesterov(p::Param, ρ, η)
-  v = zero(p.x)
-  function ()
-    d = @. ρ^2 * v - (1+ρ) * η * p.Δ
-    @. v = ρ*v - η*p.Δ
-    @. p.Δ = -d
-  end
-end
-
-function rmsprop(p::Param; η::Real = 0.001, ρ::Real = 0.9, ϵ::Real = 1e-8)
-  acc = zero(p.x)
-  function ()
-    @. acc = ρ * acc + (1 - ρ) * p.Δ^2
-    @. p.Δ *= η / (acc + ϵ)
-  end
-end
-
-function adagrad(p::Param; η::Real = 0.01, ϵ::Real = 1e-8)
-  acc = zero(p.x) .+ ϵ
-  function ()
-    @. acc += p.Δ^2
-    @. p.Δ *= η / (acc + ϵ)
-  end
-end
-
-function adadelta(p::Param; ρ::Real = 0.9, ϵ::Real = 1e-8)
-  acc = zero(p.x)
-  Δacc = zero(p.x)
-  function ()
-    @. acc = ρ * acc + (1 - ρ) * p.Δ^2
-    @. p.Δ *= (Δacc + ϵ) / (acc + ϵ)
-    @. Δacc = ρ * Δacc + (1 - ρ) * p.Δ^2
-  end
-end
-
-function adam(p::Param; η::Real = 0.001, β1::Real = 0.9, β2::Real = 0.999, ϵ::Real = 1e-8)
-  mt = zero(p.x)
-  vt = zero(p.x)
-  β1p, β2p = β1, β2
-  function ()
-    @. mt = β1 * mt + (1 - β1) * p.Δ
-    @. vt = β2 * vt + (1 - β2) * p.Δ^2
-    @. p.Δ = mt / (1 - β1p) / (vt / (1 - β2p) + ϵ) * η
-    β1p *= β1
-    β2p *= β2
-  end
-end
-
-function adamax(p::Param; η::Real = 0.002, β1::Real = 0.9, β2::Real = 0.999, ϵ::Real = 1e-8)
-  mt = zero(p.x)
-  ut = zero(p.x)
-  β1p = β1
-  function ()
-    @. mt = β1 * mt + (1 - β1) * p.Δ
-    @. ut = max(β2 * ut, abs(p.Δ))
-    @. p.Δ = (η/(1 - β1p)) * mt/(ut + ϵ)
-    β1p *= β1
-  end
-end
-
-function amsgrad(p::Param; η::Real = 0.001, β1::Real = 0.9, β2::Real = 0.999, ϵ::Real = 1e-8)
-  mt = zero(p.x)
-  vt = zero(p.x) .+ ϵ
-  v̂t = zero(p.x) .+ ϵ
-  function ()
-    @. mt = β1 * mt + (1 - β1) * p.Δ
-    @. vt = β2 * vt + (1 - β2) * p.Δ ^ 2
-    @. v̂t = max.(v̂t, vt)
-    @. p.Δ = η * mt / v̂t
-  end
-end
-
-function nadam(p::Param; η::Real = 0.001, β1::Real = 0.9, β2::Real = 0.999, ϵ::Real = 1e-8)
-  mt = zero(p.x)
-  vt = zero(p.x)
-  β1p, β2p = β1, β2
-  function ()
-    @. mt = β1 * mt + (1 - β1) * p.Δ
-    @. vt = β2 * vt + (1 - β2) * p.Δ^2
-    @. p.Δ = (β1 * mt / (1 - β1 * β1p) + (1 - β1) * p.Δ / (1 - β1p)) / (vt * β2 / (1 - β2p) + ϵ) * η
-    β1p *= β1
-    β2p *= β2
-  end
-end
-
-clip(p::Param, thresh::Real) = () -> clamp!(p.Δ, -thresh, thresh)
-
-function expdecay(p::Param, γ::Real)
-  if γ != 0
-    return () -> p.Δ .+= γ .* p.x
-  else
-    return () -> nothing
-  end
-end
-
-function invdecay(p::Param, γ::Real)
-  if γ != 0
-    n = 0
-    return () -> begin
-      p.Δ .*= 1 / (1 + γ * n)
-      n += 1
-    end
-  else
-    return () -> nothing
-  end
-end
+using Flux
+using Base: @get!
+using MacroTools: @forward
+
+const ϵ = 1e-8
+
+# TODO: should use weak refs
+
+"""
+    Descent(η)
+
+Classic gradient descent optimiser with learning rate `η`.
+For each parameter `p` and its gradient `δp`, this runs `p -= η*δp`.
+"""
+mutable struct Descent
+  eta::Float64
+end
+
+Descent() = Descent(0.1)
+
+function update!(o::Descent, x, Δ)
+  Δ .*= o.eta
+end
+
+"""
+    Momentum(params, η = 0.01; ρ = 0.9)
+
+Gradient descent with learning rate `η` and momentum `ρ`.
+"""
+mutable struct Momentum
+  eta::Float64
+  rho::Float64
+  velocity::IdDict
+end
+
+Momentum(η = 0.01, ρ = 0.9) = Momentum(η, ρ, IdDict())
+
+function update!(o::Momentum, x, Δ)
+  η, ρ = o.eta, o.rho
+  v = get!(o.velocity, x, zero(x))::typeof(x)
+  @. v = ρ * v - η * Δ
+  @. Δ = -v
+end
+
+"""
+    Nesterov(eta, ρ = 0.9)
+
+Gradient descent with learning rate `η` and Nesterov momentum `ρ`.
+"""
+mutable struct Nesterov
+  eta::Float64
+  rho::Float64
+  velocity::IdDict
+end
+
+Nesterov(η = 0.001, ρ = 0.9) = Nesterov(η, ρ, IdDict())
+
+function update!(o::Nesterov, x, Δ)
+  η, ρ = o.eta, o.rho
+  v = get!(o.velocity, x, zero(x))::typeof(x)
+  d = @. ρ^2 * v - (1+ρ) * η * Δ
+  @. v = ρ*v - η*Δ
+  @. Δ = -d
+end
+
+"""
+    RMSProp(η = 0.001, ρ = 0.9)
+
+[RMSProp](http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf)
+optimiser. Parameters other than learning rate don't need tuning. Often a good
+choice for recurrent networks.
+"""
+mutable struct RMSProp
+  eta::Float64
+  rho::Float64
+  acc::IdDict
+end
+
+RMSProp(η = 0.001, ρ = 0.9) = RMSProp(η, ρ, IdDict())
+
+function update!(o::RMSProp, x, Δ)
+  η, ρ = o.eta, o.rho
+  acc = get!(o.acc, x, zero(x))::typeof(x)
+  @. acc = ρ * acc + (1 - ρ) * Δ^2
+  @. Δ *= η / (acc + ϵ)
+end
+
+"""
+    ADAM(η = 0.001, β = (0.9, 0.999))
+
+[ADAM](https://arxiv.org/abs/1412.6980v8) optimiser.
+"""
+mutable struct ADAM
+  eta::Float64
+  beta::Tuple{Float64,Float64}
+  state::IdDict
+end
+
+ADAM(η = 0.001, β = (0.9, 0.999)) = ADAM(η, β, IdDict())
+
+function update!(o::ADAM, x, Δ)
+  η, β = o.eta, o.beta
+  mt, vt, βp = get!(o.state, x, (zero(x), zero(x), β))
+  @. mt = β[1] * mt + (1 - β[1]) * Δ
+  @. vt = β[2] * vt + (1 - β[2]) * Δ^2
+  @. Δ = mt / (1 - βp[1]) / ((vt / (1 - βp[2])) + ϵ) * η
+  o.state[x] = (mt, vt, βp .* β)
+  return Δ
+end
+
+"""
+    AdaMax(params, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08)
+
+[AdaMax](https://arxiv.org/abs/1412.6980v9) optimiser. Variant of ADAM based on
+the ∞-norm.
+"""
+mutable struct AdaMax
+  eta::Float64
+  beta::Tuple{Float64,Float64}
+  state::IdDict
+end
+
+AdaMax(η = 0.001, β = (0.9, 0.999)) = AdaMax(η, β, IdDict())
+
+function update!(o::AdaMax, x, Δ)
+  η, β = o.eta, o.beta
+  mt, ut, βp = get!(o.state, x, (zero(x), zero(x), β))
+  @. mt = β[1] * mt + (1 - β[1]) * Δ
+  @. ut = max(β[2] * ut, abs(Δ))
+  @. Δ = (η/(1 - βp[1])) * mt/(ut + ϵ)
+  o.state[x] = (mt, ut, βp .* β)
+  return Δ
+end
+
+"""
+    ADAGrad(η = 0.1; ϵ = 1e-8)
+
+[ADAGrad](http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf) optimiser.
+Parameters don't need tuning.
+"""
+mutable struct ADAGrad
+  eta::Float64
+  acc::IdDict
+end
+
+ADAGrad(η = 0.1) = ADAGrad(η, IdDict())
+
+function update!(o::ADAGrad, x, Δ)
+  η = o.eta
+  acc = get!(o.acc, x, fill(ϵ, size(x)))::typeof(x)
+  @. acc += Δ^2
+  @. Δ *= η / (acc + ϵ)
+end
+
+"""
+    ADADelta(ρ = 0.9, ϵ = 1e-8)
+
+[ADADelta](http://arxiv.org/abs/1212.5701) optimiser. Parameters don't need
+tuning.
+"""
+mutable struct ADADelta
+  rho::Float64
+  state::IdDict
+end
+
+ADADelta(ρ = 0.9) = ADADelta(ρ, IdDict())
+
+function update!(o::ADADelta, x, Δ)
+  ρ = o.rho
+  acc, Δacc = get!(o.state, x, (zero(x), zero(x)))
+  @. acc = ρ * acc + (1 - ρ) * Δ^2
+  @. Δ *= (Δacc + ϵ) / (acc + ϵ)
+  @. Δacc = ρ * Δacc + (1 - ρ) * Δ^2
+  return Δ
+end
+
+"""
+    AMSGrad(η = 0.001, β = (0.9, 0.999))
+
+[AMSGrad](https://openreview.net/forum?id=ryQu7f-RZ) optimiser. Parameters don't need
+tuning.
+"""
+mutable struct AMSGrad
+  eta::Float64
+  beta::Tuple{Float64, Float64}
+  state::IdDict
+end
+
+AMSGrad(η = 0.001, β = (0.9, 0.999)) = AMSGrad(η, β, IdDict())
+
+function update!(o::AMSGrad, x, Δ)
+  η, β = o.eta, o.beta
+  mt, vt, v̂t = get!(o.state, x, (fill(ϵ, size(x)), fill(ϵ, size(x)), fill(ϵ, size(x))))
+  @. mt = β[1] * mt + (1 - β[1]) * Δ
+  @. vt = β[2] * vt + (1 - β[2]) * Δ ^ 2
+  @. v̂t = max.(v̂t, vt)
+  @. Δ = η * mt / v̂t
+end
+
+"""
+    NADAM(η = 0.001, β = (0.9, 0.999))
+
+[NADAM](http://cs229.stanford.edu/proj2015/054_report.pdf) optimiser. Parameters don't need
+tuning.
+"""
+mutable struct NADAM
+  eta::Float64
+  beta::Tuple{Float64, Float64}
+  state::IdDict
+end
+
+NADAM(η = 0.001, β = (0.9, 0.999)) = NADAM(η, β, IdDict())
+
+function update!(o::NADAM, x, Δ)
+  η, β = o.eta, o.beta
+  β1p, β2p = o.beta
+  mt, vt = get!(o.state, x, (zero(x), zero(x)))
+  @. mt = β[1] * mt + (1 - β[1]) * Δ
+  @. vt = β[2] * vt + (1 - β[2]) * Δ^2
+  @. Δ = (β[1] * mt / (1 - β[1] * β1p) + (1 - β[1]) * Δ / (1 - β1p)) / (vt * β[2] / (1 - β2p) + ϵ) * η
+  o.state[x] = (mt, vt, (β1p * β[1], β2p * β[2]))
+  return Δ
+end
+
+"""
+    ADAMW((η = 0.001, β = (0.9, 0.999), decay = 0)
+
+[ADAMW](https://arxiv.org/abs/1711.05101) fixing weight decay regularization in Adam.
+"""
+ADAMW(η = 0.001, β = (0.9, 0.999), decay = 0) =
+  Optimiser(ADAM(η, β), WeightDecay(wd))
+
+# Compose optimizers
+
+"""
+    Optimiser(a, b, c...)
+
+Combine several optimisers into one; each optimiser produces a modified gradient
+that will be fed into the next, and this is finally applied to the parameter as
+usual.
+"""
+mutable struct Optimiser
+  os::Vector{Any}
+end
+
+Optimiser(o...) = Optimiser(Any[o...])
+
+@forward Optimiser.os Base.getindex, Base.first, Base.last, Base.lastindex, Base.push!, Base.setindex!
+@forward Optimiser.os Base.iterate
+
+Base.getindex(c::Optimiser, i::AbstractArray) = Optimiser(c.os[i]...)
+
+function update!(o::Optimiser, x, Δ)
+  for opt in o.os
+    Δ = update!(opt, x, Δ)
+  end
+  return Δ
+end
+
+mutable struct InvDecay
+  gamma::Float64
+  state::IdDict
+end
+
+InvDecay(γ = 0.001) = InvDecay(γ, IdDict())
+
+function update!(o::InvDecay, x, Δ)
+  γ = o.gamma
+  n = get!(o.state, x, 1)
+  Δ .*= 1 / (1 + γ * n)
+  o.state[x] = n + 1
+  return Δ
+end
+
+mutable struct ExpDecay
+  eta::Float64
+  decay::Float64
+  step::Int64
+  clip::Float64
+  current::IdDict
+end
+
+ExpDecay(opt = 0.001, decay = 0.1, decay_step = 1000, clip = 1e-4) = ExpDecay(opt, decay, decay_step, clip, IdDict())
+
+function update!(o::ExpDecay, x, Δ)
+  η, s, decay = o.eta, o.step, o.decay
+  n = o.current[x] = get(o.current, x, 0) + 1
+  if o.current[x]%s == 0 && count(x -> x%s == 0, values(o.current)) == 1
+    η = max(η * decay^(s / n), o.clip)
+    o.eta = η
+  end
+  @. Δ *= decay
+end
+
+mutable struct WeightDecay
+  wd::Real
+end
+
+WeightDecay() = WeightDecay(0)
+
+function update!(o::WeightDecay, x, Δ)
+  wd = o.wd
+  @. Δ += wd * x
+end
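
In the new design each optimiser is a small struct: `update!(opt, x, Δ)` rescales the gradient (any per-parameter state is keyed on `x` through an `IdDict`) and returns it, and `Optimiser` simply chains these transformations. A minimal manual-stepping sketch in the spirit of the updated tests; the array `w`, `x` and `y` below are placeholders:

```julia
using Flux, Flux.Tracker, Flux.Optimise

w = param(randn(10, 10))
x, y = rand(10), rand(10)
loss() = Flux.mse(w*x, y)

opt = Optimiser(InvDecay(), ADAM(0.001))        # decay wrapper feeding into ADAM

back!(loss())
delta = Optimise.update!(opt, w.data, w.grad)   # each stage rewrites the gradient
w.data .-= delta                                # the caller applies the step
w.grad .= 0                                     # reset before the next iteration
```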

View File

@@ -1,7 +1,16 @@
 using Juno
-using Flux.Tracker: back!
+using Flux.Tracker: data, grad, back!
 import Base.depwarn

+function update!(opt, xs)
+  for x in xs
+    Δ = update!(opt, x.data, x.grad)
+    x.data .-= Δ
+    Δ .= 0
+  end
+end
+
+# Callback niceties
 runall(f) = f
 runall(fs::AbstractVector) = () -> foreach(call, fs)
@@ -35,7 +44,7 @@ function stop()
 end

 """
-    train!(loss, data, opt)
+    train!(model, loss, data, opt)

 For each datapoint `d` in `data` computes the gradient of `loss(d...)` through
 backpropagation and calls the optimizer `opt`.
@@ -44,7 +53,7 @@ Takes a callback as keyword argument `cb`. For example, this will print "training"
 every 10 seconds:

 ```julia
-Flux.train!(loss, data, opt,
+Flux.train!(model, loss, data, opt,
             cb = throttle(() -> println("training"), 10))
 ```
@@ -52,14 +61,14 @@ The callback can return `:stop` to interrupt the training loop.
 Multiple optimisers and callbacks can be passed to `opt` and `cb` as arrays.
 """
-function train!(loss, data, opt; cb = () -> ())
+function train!(loss, ps, data, opt; cb = () -> ())
   cb = runall(cb)
   opt = runall(opt)
   @progress for d in data
     try
       l = loss(d...)
       @interrupts back!(l)
-      opt()
+      update!(opt, ps)
       if cb() == :stop
         depwarn("Use of `:stop` is deprecated; use `Flux.stop()` instead", :stop)
         break
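
With this change the parameters become an explicit argument to `train!`. A short sketch under the new signature; the model, loss and data here are placeholders, not part of the diff:

```julia
using Flux

m = Dense(10, 2)                       # placeholder model
data = [(rand(10), rand(2))]           # placeholder dataset
loss(x, y) = Flux.mse(m(x), y)

opt = ADAM(0.001)
Flux.train!(loss, params(m), data, opt,
            cb = Flux.throttle(() -> println("training"), 10))
```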

View File

@@ -68,9 +68,9 @@ end
 include("idset.jl")
 include("back.jl")
-include("scalar.jl")
-include("array.jl")
 include("numeric.jl")
+include("lib/real.jl")
+include("lib/array.jl")

 """
     hook(f, x) -> x

View File

@@ -19,62 +19,78 @@ function scan(x)
   return
 end

-function back_(c::Call, Δ)
+function back_(c::Call, Δ, once)
   Δs = c.func(Δ)
   (Δs isa Tuple && length(Δs) >= length(c.args)) ||
     error("Gradient is not a tuple of length $(length(c.args))")
-  foreach(back, c.args, data.(Δs))
+  foreach((x, d) -> back(x, d, once), c.args, data.(Δs))
 end

-back_(::Call{Nothing}, Δ) = nothing
+back_(::Call{Nothing}, Δ, once) = nothing
+back_(::Call{Missing}, Δ, once) = error("`back!` was already used")

 accum!(x, Δ) = x .+ Δ
 accum!(x::AbstractArray, Δ) = (x .+= Δ)

-function back(x::Tracked, Δ)
+function back(x::Tracked, Δ, once)
   x.isleaf && (x.grad = accum!(x.grad, Δ); return)
   ref = x.ref -= 1
-  if ref > 0 || isdefined(x, :grad)
-    if isdefined(x, :grad)
-      x.grad = accum!(x.grad, Δ)
-    else
-      x.grad = Δ
-    end
-    ref == 0 && back_(x.f, x.grad)
+  grad = if isdefined(x, :grad)
+    x.grad = accum!(x.grad, Δ)
+  elseif ref > 0
+    x.grad = Δ
   else
-    ref == 0 && back_(x.f, Δ)
+    Δ
+  end
+  if ref == 0
+    back_(x.f, grad, once)
+    once && !x.isleaf && (x.f = Call(missing, ()))
   end
   return
 end

-back(::Nothing, _) = return
+back(::Nothing, Δ, once) = return

 # Interface methods

 # TODO: if an error occurs in `back` the refcounts will be broken
 # and `back` will silently fail to update.
+# (but only if you re-use intermediate values between passes)

 # Refcounts are also probably not safe in some situations (e.g. back called
 # from within a backpropagator)

-function back!(x, Δ)
+function back!(x, Δ; once = true)
   istracked(x) || return
   scan(x)
-  back(tracker(x), Δ)
+  back(tracker(x), Δ, once)
   return
 end

 # Out-of-place gradients

 struct Params
-  params::IdSet
-  Params(xs) = new(IdSet(xs))
+  order::Vector{Any}
+  params::IdSet{Any}
+  Params() = new([], IdSet())
 end

-@forward Params.params Base.iterate, Base.length
+@forward Params.order Base.iterate, Base.length
+
+function Base.push!(ps::Params, x)
+  if !(x in ps.params)
+    push!(ps.order, x)
+    push!(ps.params, x)
+  end
+  return ps
+end
+
+Base.push!(ps::Params, x...) = (foreach(x -> push!(ps, x), x); ps)
+
+Params(xs) = push!(Params(), xs...)

 function Base.show(io::IO, ps::Params)
   print(io, "Params([")
-  join(io, ps.params, ", ")
+  join(io, ps.order, ", ")
   print(io, "])")
 end
@@ -91,12 +107,12 @@ Grads() = Grads(IdDict())
 Grads(ps::Params) = Grads(IdDict(tracker(p) => init_grad(data(p)) for p in ps))

 Base.getindex(g::Grads, x::Tracked) = g.grads[x]
 function Base.getindex(g::Grads, x)
   istracked(x) || error("Object not tracked: $x")
   g[tracker(x)]
 end

 accum!(g::Grads, x, Δ) = g[x] = haskey(g, x) ? g[x] .+ Δ : Δ

 function back_(g::Grads, c::Call, Δ)
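
`back!` now frees the tape after the first pass by default (`once = true`), so backpropagating twice through the same graph raises "`back!` was already used"; passing `once = false` keeps the intermediates, which is what the updated tests and `jacobian` rely on. A small sketch:

```julia
using Flux

x = param([1.0])
l = sum((x .+ x).^2)

Flux.back!(l, once = false)   # keep the tape so it can be walked again
x.grad                        # == [8.0]
x.grad .= 0
Flux.back!(l, once = false)   # a second pass over the same intermediates works;
                              # with the default `once = true` it would error
x.grad                        # == [8.0] again
```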

View File

@@ -7,6 +7,7 @@ Base.eltype(::IdSet{T}) where T = T
 IdSet() = IdSet{Any}()

+Base.push!(s::IdSet) = s
 Base.push!(s::IdSet{T}, x::T) where T = (s.dict[x] = nothing; s)
 Base.delete!(s::IdSet{T}, x::T) where T = (delete!(s.dict, x); s)
 Base.in(x, s::IdSet) = haskey(s.dict, x)

View File

@@ -82,6 +82,17 @@ Base.getindex(xs::TrackedArray, i...) = track(getindex, xs, i...)
   end
 end

+Base.view(x::TrackedArray, inds...) = track(Base.view, x, inds...)
+
+@grad function view(x::AbstractArray, inds...)
+  view(data(x), inds...), function (Δ)
+    grad_output = zero(x)
+    subgrad = view(grad_output, inds...)
+    subgrad[:] = data(Δ)
+    (nobacksies(:view, grad_output), map(_->nothing, inds)...)
+  end
+end
+
 Base.:-(xs::TrackedArray) = track(-, xs)
 @grad -(xs) = -data(xs), Δ -> (-Δ,)
@@ -434,6 +445,7 @@ end
 using Requires

 # https://github.com/FluxML/Flux.jl/issues/353
+if VERSION < v"1.1.0-DEV.548"
 @init Requires.isprecompiling() || @eval Base.Broadcast begin
   function flatten(bc::Broadcasted{Style}) where {Style}
     isflat(bc) && return bc
@@ -459,3 +471,4 @@ using Requires
     end
   end
 end
+end
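
`view` of a `TrackedArray` is now differentiable: the pullback writes the incoming gradient into the viewed slice of an otherwise zero array. A small sketch:

```julia
using Flux

x = param(rand(4, 4))
l = sum(view(x, 1:2, 1:2))   # only the top-left 2×2 block enters the loss
Flux.back!(l)
x.grad                       # ones in that 2×2 block, zeros elsewhere
```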

View File

@@ -10,10 +10,10 @@ tracker(x::TrackedReal) = x.tracker
 track(f::Call, x::Real) = TrackedReal(x, Tracked{typeof(x)}(f, zero(x)))

-function back!(x::TrackedReal)
+function back!(x::TrackedReal; once = true)
   isinf(x) && error("Loss is Inf")
   isnan(x) && error("Loss is NaN")
-  return back!(x, 1)
+  return back!(x, 1, once = once)
 end

 function Base.show(io::IO, x::TrackedReal)
@@ -123,8 +123,8 @@ function scan(c::Call{typeof(collect)})
   foreach(scan, c.args[1])
 end

-function back_(c::Call{typeof(collect)}, Δ)
-  foreach(back, c.args[1], data(Δ))
+function back_(c::Call{typeof(collect)}, Δ, once)
+  foreach((x, d) -> back(x, d, once), c.args[1], data(Δ))
 end

 function back_(g::Grads, c::Call{typeof(collect)}, Δ)

View File

@@ -40,7 +40,7 @@ function prefor(f, x; seen = IdSet())
 end

 function params(m)
-  ps = []
+  ps = Params()
   prefor(p ->
     Tracker.istracked(p) && Tracker.isleaf(p) &&
       !any(p′ -> p′ === p, ps) && push!(ps, p),
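
`params` now collects into the ordered, de-duplicated `Params` container from `Tracker` instead of a plain `Vector`, so its result can be handed directly to the new `train!`. A small sketch:

```julia
using Flux

m = Chain(Dense(10, 5, relu), Dense(5, 2))
ps = Flux.params(m)   # Params([...]), in model order, duplicates skipped
length(ps)            # 4: a weight matrix and a bias for each Dense layer
```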

View File

@@ -147,9 +147,9 @@ function jacobian(m,x)
   n = length(x)
   J = Matrix{eltype(x)}(undef,n,k)
   for i = 1:k
-    Flux.back!(y[i]) # Populate gradient accumulator
+    Flux.back!(y[i], once = false) # Populate gradient accumulator
     J[:,i] = xp.grad
-    xp.grad .*= 0 # Reset gradient accumulator
+    xp.grad .= 0 # Reset gradient accumulator
   end
   J'
 end
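
`jacobian` backpropagates once per output component from the same graph, which is why it needs `once = false` above. A usage sketch with a placeholder model:

```julia
using Flux

m = Dense(3, 2)           # placeholder model
x = rand(3)
J = Flux.jacobian(m, x)   # 2×3 Jacobian of the outputs with respect to x
```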

View File

@@ -3,14 +3,37 @@ using Flux.Tracker
 using Test

 @testset "Optimise" begin
   w = randn(10, 10)
-  @testset for Opt in [SGD, Nesterov, Momentum, ADAM, AdaMax, RMSProp, ps -> ADAGrad(ps, 0.1), ADADelta, AMSGrad, NADAM]
+  @testset for Opt in [ADAGrad, AdaMax, ADADelta, AMSGrad, NADAM, Descent, ADAM, Nesterov, RMSProp, Momentum]
     w = param(randn(10, 10))
     loss(x) = Flux.mse(w*x, w*x)
-    opt = Opt([w])
+    opt = Opt(0.001)
+    if opt isa Descent || opt isa ADAGrad
+      opt = Opt(0.1)
+    end
+    if opt isa ADADelta
+      opt = Opt(0.9)
+    end
     for t = 1: 10^5
       l = loss(rand(10))
       back!(l)
-      opt()
+      delta = Optimise.update!(opt, w.data, w.grad)
+      w.data .-= delta
+    end
+    @test Flux.mse(w, w) < 0.01
+  end
+end
+
+@testset "Optimiser" begin
+  w = randn(10, 10)
+  @testset for Opt in [InvDecay, WeightDecay, ExpDecay]
+    w = param(randn(10, 10))
+    loss(x) = Flux.mse(w*x, w*x)
+    opt = Optimiser(Opt(), ADAM(0.001))
+    for t = 1:10^5
+      l = loss(rand(10))
+      back!(l)
+      delta = Optimise.update!(opt, w.data, w.grad)
+      w.data .-= delta
     end
     @test Flux.mse(w, w) < 0.01
   end
@@ -21,9 +44,10 @@ end
   l = param(1)

   Flux.train!(() -> (sleep(0.1); i += 1; l),
+              (),
               Iterators.repeated((), 100),
-              ()->(),
+              Descent(),
-              cb = Flux.throttle(() -> (i > 3 && stop()), 1))
+              cb = Flux.throttle(() -> (i > 3 && Flux.stop()), 1))
   @test 3 < i < 50
 end

View File

@@ -33,6 +33,11 @@ gradtest(f, dims...) = gradtest(f, rand.(Float64, dims)...)
 @test gradtest(Flux.crossentropy, rand(5,5), rand(5, 5))

 @test gradtest(x -> x', rand(5))

+@testset "indexing & slicing" begin
+  gradtest(x->view(x, 1:2, 1:2), rand(4, 4))
+end
+
 function promotiontest(f, A, B, C)
   r0 = f(A, B, C)
   r1 = f(param(A), B, C)
@@ -232,10 +237,10 @@ end
 @testset "Intermediates" begin
   x = param([1])
   l = sum((x .+ x).^2)
-  Flux.back!(l)
+  Flux.back!(l, once = false)
   @test x.grad == [8]
   x.grad .= 0
-  Flux.back!(l)
+  Flux.back!(l, once = false)
   @test x.grad == [8]
 end