Merge branch 'master' into HEAD

Mike J Innes 2017-12-08 18:31:55 +00:00
commit 1d916c81b5
11 changed files with 132 additions and 72 deletions

@@ -8,11 +8,11 @@ using Juno, Requires
using Lazy: @forward

export Chain, Dense, RNN, LSTM, GRU, Dropout, LayerNorm,
-  SGD, ADAM, Momentum, Nesterov,
+  SGD, ADAM, Momentum, Nesterov, AMSGrad,
  param, params, mapleaves

using NNlib
-export σ, relu, leakyrelu, elu, swish, softmax
+export σ, sigmoid, relu, leakyrelu, elu, swish, softmax
include("tracker/Tracker.jl")
using .Tracker

@@ -33,7 +33,7 @@ function rawdict()
  filter(!isempty, split.(split(readstring(deps("CMUDict", "cmudict")), "\n"))))
end

-validword(s) = ismatch(r"^[\w-\.]+$", s)
+validword(s) = ismatch(r"^[\w\-\.]+$", s)

cmudict() = filter((s, ps) -> validword(s), rawdict())
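
A quick illustration of the fixed pattern (a sketch, not part of the diff; `ismatch` is the Julia 0.6 API): escaping the hyphen makes its literal reading explicit rather than relying on how the regex engine treats a `-` that follows the `\w` class escape, so hyphenated and dotted entries validate and everything else is rejected:

    validword("mother-in-law")  # true: word characters and hyphens
    validword("dr.")            # true: dots are allowed
    validword("can't")          # false: apostrophes are rejected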

@@ -151,9 +151,6 @@ for a good overview of the internals.
"""
LSTM(a...; ka...) = Recur(LSTMCell(a...; ka...))
# GRU
struct GRUCell{D1,D2,V}

@@ -1,15 +1,17 @@
+using NNlib: log_fast

# Cost functions

mse(ŷ, y) = sum((ŷ .- y).^2)/length(y)

crossentropy(ŷ::AbstractVecOrMat, y::AbstractVecOrMat) =
-  -sum(y .* log.(ŷ)) / size(y, 2)
+  -sum(y .* log_fast.(ŷ)) / size(y, 2)

@deprecate logloss(x, y) crossentropy(x, y)

function logitcrossentropy(logŷ::AbstractVecOrMat, y::AbstractVecOrMat)
  logŷ = logŷ .- maximum(logŷ, 1)
-  ypred = logŷ .- log.(sum(exp.(logŷ), 1))
+  ypred = logŷ .- log_fast.(sum(exp.(logŷ), 1))
  -sum(y .* ypred) / size(y, 2)
end
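
As a sanity check on the two losses (an illustrative sketch, not part of the commit): `logitcrossentropy` on raw scores should agree with `crossentropy` after an explicit `softmax`, while avoiding the exp/log round trip on large scores:

    logŷ = randn(3, 2)     # raw scores: 3 classes × 2 samples
    y = [1 0; 0 0; 0 1]    # one-hot targets
    logitcrossentropy(logŷ, y) ≈ crossentropy(softmax(logŷ), y)  # true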

@@ -42,7 +42,14 @@ function onehot(l, labels)
  OneHotVector(i, length(labels))
end

onehotbatch(ls, labels) = OneHotMatrix(length(labels), [onehot(l, labels) for l in ls])

+function onehot(l, labels, unk)
+  i = findfirst(labels, l)
+  i > 0 || return onehot(unk, labels)
+  OneHotVector(i, length(labels))
+end
+onehotbatch(ls, labels, unk...) =
+  OneHotMatrix(length(labels), [onehot(l, labels, unk...) for l in ls])

argmax(y::AbstractVector, labels = 1:length(y)) =
  labels[findfirst(y, maximum(y))]
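
Illustrative use of the new `unk` fallback (a sketch; the symbols are arbitrary): a label missing from `labels` now encodes as a designated unknown token instead of yielding an invalid index:

    onehot(:b, [:a, :b, :c])       # one-hot vector selecting index 2
    onehot(:d, [:a, :b, :c], :a)   # :d is unknown, falls back to :a
    onehotbatch([:b, :d], [:a, :b, :c], :a)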

@@ -1,7 +1,7 @@
module Optimise

export update!, params, train!,
-  SGD, ADAM, Momentum, Nesterov, RMSProp, ADAGrad, ADADelta
+  SGD, ADAM, Momentum, Nesterov, RMSProp, ADAGrad, ADADelta, AMSGrad

struct Param{T}
  x::T

@@ -1,5 +1,7 @@
call(f, xs...) = f(xs...)

+# note for optimisers: set p.Δ to zero
+# at the end of the weights update
function optimiser(ps, fs...)
  ps = [Param(p) for p in ps]
  fs = map(ps) do p

@@ -10,64 +12,73 @@ function optimiser(ps, fs...)
end
"""
SGD(params, η = 1; decay = 0)
SGD(params, η = 0.1; decay = 0)
Classic gradient descent optimiser. For each parameter `p` and its
gradient `δp`, this runs `p -= η*δp`.
Classic gradient descent optimiser with learning rate `η`.
For each parameter `p` and its gradient `δp`, this runs `p -= η*δp`.
Supports decayed learning rate decay if the `decay` argument is provided.
Supports inverse decaying learning rate if the `decay` argument is provided.
"""
SGD(ps, η = 1; decay = 0) =
optimiser(ps, p -> invdecay(p, decay), p -> descent(p, η))
SGD(ps, η = 0.1; decay = 0) =
optimiser(ps, p -> invdecay(p, decay), p -> descent(p,η))
"""
Momentum(params, ρ, decay = 0)
Momentum(params, η = 0.01; ρ = 0.9, decay = 0)
SGD with momentum `ρ` and optional learning rate decay.
SGD with learning rate `η`, momentum `ρ` and optional learning rate inverse decay.
"""
Momentum(ps, ρ; decay = 0) =
optimiser(ps, p -> momentum(p, ρ), p -> invdecay(p, decay), p -> descent(p, 1))
Momentum(ps, η = 0.01; ρ = 0.9, decay = 0) =
optimiser(ps, p->invdecay(p,decay), p->momentum(p, ρ, η), p->descent(p,1))
"""
Nesterov(params, ρ, decay = 0)
Nesterov(params, η = 0.01; ρ = 0.9, decay = 0)
SGD with Nesterov momentum `ρ` and optional learning rate decay.
SGD with learning rate `η`, Nesterov momentum `ρ` and optional learning rate inverse decay.
"""
Nesterov(ps, ρ; decay = 0) =
optimiser(ps, p -> nesterov(p, ρ), p -> invdecay(p, decay), p -> descent(p, 1))
Nesterov(ps, η = 0.01; ρ = 0.9, decay = 0) =
optimiser(ps, p->invdecay(p,decay), p->nesterov(p, ρ, η), p->descent(p,1))
"""
RMSProp(params; η = 0.001, ρ = 0.9, ϵ = 1e-8, decay = 0)
RMSProp(params, η = 0.001; ρ = 0.9, ϵ = 1e-8, decay = 0)
[RMSProp](http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf)
optimiser. Parameters other than learning rate don't need tuning. Often a good
choice for recurrent networks.
"""
RMSProp(ps, η = 0.001; ρ = 0.9, ϵ = 1e-8, decay = 0) =
optimiser(ps, p -> rmsprop(p; η = η, ρ = ρ, ϵ = ϵ), p -> invdecay(p, decay), p -> descent(p, 1))
optimiser(ps, p->rmsprop(p; η=η, ρ=ρ, ϵ=ϵ), p->invdecay(p,decay), p->descent(p,1))
"""
ADAM(params; η = 0.001, β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0)
ADAM(params, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0)
[ADAM](https://arxiv.org/abs/1412.6980v8) optimiser.
"""
ADAM(ps, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0) =
optimiser(ps, p -> adam(p; η = η, β1 = β1, β2 = β2, ϵ = ϵ), p -> invdecay(p, decay), p -> descent(p, 1))
optimiser(ps, p->adam(p; η=η, β1=β1, β2=β2, ϵ=ϵ), p->invdecay(p,decay), p->descent(p,1))
"""
ADAGrad(params; η = 0.01, ϵ = 1e-8, decay = 0)
ADAGrad(params, η = 0.01; ϵ = 1e-8, decay = 0)
[ADAGrad](http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf) optimiser.
Parameters don't need tuning.
"""
ADAGrad(ps; η = 0.01, ϵ = 1e-8, decay = 0) =
optimiser(ps, p -> adagrad(p; η = η, ϵ = ϵ), p -> invdecay(p, decay), p -> descent(p, 1))
ADAGrad(ps, η = 0.01; ϵ = 1e-8, decay = 0) =
optimiser(ps, p->adagrad(p; η=η, ϵ=ϵ), p->invdecay(p,decay), p->descent(p,1))
"""
ADADelta(params; η = 0.01, ρ = 0.95, ϵ = 1e-8, decay = 0)
ADADelta(params; ρ = 0.9, ϵ = 1e-8, decay = 0)
[ADADelta](http://arxiv.org/abs/1212.5701) optimiser. Parameters don't need
tuning.
"""
ADADelta(ps; η = 0.01, ρ = 0.95, ϵ = 1e-8, decay = 0) =
optimiser(ps, p -> adadelta(p; ρ = ρ, ϵ = ϵ), p -> invdecay(p, decay), p -> descent(p, 1))
ADADelta(ps; ρ = 0.9, ϵ = 1e-8, decay = 0) =
optimiser(ps, p->adadelta(p; ρ=ρ, ϵ=ϵ), p->descent(p,1))
"""
AMSGrad(params; η = 0.001, β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0)
[AMSGrad](https://openreview.net/forum?id=ryQu7f-RZ) optimiser. Parameters don't need
tuning.
"""
AMSGrad(ps, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0) =
optimiser(ps, p -> amsgrad(p; η = η, β1 = β1, β2 = β2, ϵ = ϵ), p -> invdecay(p, decay), p -> descent(p, 1))
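
Putting the constructors together (an illustrative sketch assuming `using Flux`, not part of the commit): each constructor returns a zero-argument update function that chains its pieces per parameter; calling it applies one step, after which `descent` resets the accumulated `p.Δ`, per the note at the top of the file:

    W, b = param(randn(3, 3)), param(zeros(3))
    opt = SGD([W, b], 0.1; decay = 0.01)  # chains invdecay, then descent
    # ... back!(loss) accumulates gradients into W.Δ and b.Δ ...
    opt()  # scale Δ, take the descent step, and zero Δ for the next step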

@@ -1,74 +1,97 @@
function descent(p::Param, η::Real)
  function ()
-    p.x .-= p.Δ .* η
-    p.Δ .= 0
+    @. p.x -= η * p.Δ
+    @. p.Δ = 0
  end
end

-function momentum(p::Param, ρ::Real)
-  mo = zeros(p.x)
-  () -> p.Δ .= mo .= ρ .* mo .+ p.Δ
-end
+function momentum(p::Param, ρ, η)
+  v = zeros(p.x)
+  function ()
+    @. v = ρ * v - η * p.Δ
+    @. p.Δ = -v
+  end
+end

-function nesterov(p::Param, ρ::Real)
-  mo = zeros(p.x)
-  function ()
-    mo .= ρ .* mo .+ p.Δ
-    p.Δ .= ρ .* mo .+ p.Δ
-  end
-end
+# Ref. https://arxiv.org/pdf/1212.0901.pdf
+function nesterov(p::Param, ρ, η)
+  v = zeros(p.x)
+  function ()
+    d = @. ρ^2 * v - (1+ρ) * η * p.Δ
+    @. v = ρ*v - η*p.Δ
+    @. p.Δ = -d
+  end
+end

-function clip(p::Param, thresh::Real)
-  () -> clamp!(p.Δ, -thresh, thresh)
-end
-
-function weightdecay(p::Param, γ::Real)
-  () -> p.Δ .+= γ .* p.x
-end
-
-function invdecay(p::Param, γ::Real)
-  n = 0
-  function ()
-    p.Δ .*= 1 / (1 + γ * n)
-    n += 1
-  end
-end

function rmsprop(p::Param; η::Real = 0.001, ρ::Real = 0.9, ϵ::Real = 1e-8)
-  acc = zeros(p.x) .+ ϵ
+  acc = zeros(p.x)
  function ()
-    @. acc = ρ * acc + (1 - ρ) * p.Δ ^ 2
-    @. p.Δ *= η / √acc
+    @. acc = ρ * acc + (1 - ρ) * p.Δ^2
+    @. p.Δ *= η / √(acc + ϵ)
  end
end

function adagrad(p::Param; η::Real = 0.01, ϵ::Real = 1e-8)
  acc = zeros(p.x) .+ ϵ
  function ()
-    @. acc += p.Δ ^ 2
+    @. acc += p.Δ^2
    @. p.Δ *= η / √acc
  end
end

-function adadelta(p::Param; ρ::Real = 0.95, ϵ::Real = 1e-8)
-  acc = zeros(p.x) .+ ϵ
-  Δacc = zeros(p.x) .+ ϵ
+function adadelta(p::Param; ρ::Real = 0.9, ϵ::Real = 1e-8)
+  acc = zeros(p.x)
+  Δacc = zeros(p.x)
  function ()
-    @. acc = ρ * acc + (1 - ρ) * p.Δ ^ 2
-    @. p.Δ *= √Δacc / √acc
-    @. Δacc = ρ * Δacc + (1 - ρ) * p.Δ ^ 2
+    @. acc = ρ * acc + (1 - ρ) * p.Δ^2
+    @. p.Δ *= √(Δacc + ϵ) / √(acc + ϵ)
+    @. Δacc = ρ * Δacc + (1 - ρ) * p.Δ^2
  end
end

function adam(p::Param; η::Real = 0.001, β1::Real = 0.9, β2::Real = 0.999, ϵ::Real = 1e-8)
  mt = zeros(p.x)
-  vt = zeros(p.x) .+ ϵ
+  vt = zeros(p.x)
  β1p, β2p = β1, β2
  function ()
    @. mt = β1 * mt + (1 - β1) * p.Δ
-    @. vt = β2 * vt + (1 - β2) * p.Δ ^ 2
-    @. p.Δ = √(1 - β2p) / (1 - β1p) * mt / √vt * η
+    @. vt = β2 * vt + (1 - β2) * p.Δ^2
+    @. p.Δ = mt / (1 - β1p) / (√(vt / (1 - β2p)) + ϵ) * η
    β1p *= β1
    β2p *= β2
  end
end

+function amsgrad(p::Param; η::Real = 0.001, β1::Real = 0.9, β2::Real = 0.999, ϵ::Real = 1e-8)
+  mt = zeros(p.x)
+  vt = zeros(p.x) .+ ϵ
+  v̂t = zeros(p.x) .+ ϵ
+  function ()
+    @. mt = β1 * mt + (1 - β1) * p.Δ
+    @. vt = β2 * vt + (1 - β2) * p.Δ^2
+    @. v̂t = max.(v̂t, vt)
+    @. p.Δ = η * mt / √v̂t
+  end
+end
+
+clip(p::Param, thresh::Real) = () -> clamp!(p.Δ, -thresh, thresh)
+
+function expdecay(p::Param, γ::Real)
+  if γ != 0
+    return () -> p.Δ .+= γ .* p.x
+  else
+    return () -> nothing
+  end
+end
+
+function invdecay(p::Param, γ::Real)
+  if γ != 0
+    n = 0
+    return () -> begin
+      p.Δ .*= 1 / (1 + γ * n)
+      n += 1
+    end
+  else
+    return () -> nothing
+  end
+end
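
The rewritten `adam` applies the bias corrections to each moment and keeps `ϵ` outside the square root, instead of pre-scaling the whole step as before; a minimal standalone sketch of a single step (illustrative names `W`/`dW`, Julia 0.6 style, not the Flux API):

    η, β1, β2, ϵ = 0.001, 0.9, 0.999, 1e-8
    W, dW = randn(5, 5), randn(5, 5)   # parameter and its current gradient
    mt, vt = zeros(W), zeros(W)        # first/second moment estimates
    β1p, β2p = β1, β2                  # running powers β1^t, β2^t
    @. mt = β1 * mt + (1 - β1) * dW
    @. vt = β2 * vt + (1 - β2) * dW^2
    # m̂ = mt/(1-β1^t), v̂ = vt/(1-β2^t); the step is η*m̂/(√v̂ + ϵ)
    @. W -= η * (mt / (1 - β1p)) / (√(vt / (1 - β2p)) + ϵ)
    β1p *= β1; β2p *= β2               # advance the powers for the next step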

@@ -40,7 +40,7 @@ TrackedArray(x::AbstractArray) = TrackedArray(Call(nothing), x, zeros(x))
isleaf(x::TrackedArray) = x.f == Call(nothing)

-param(xs) = TrackedArray(AbstractFloat.(xs))
+param(xs) = TrackedArray(map(x -> AbstractFloat(x), xs))
param(xs::Real) = param(fill(xs))

istracked(x::TrackedArray) = true
@@ -58,6 +58,7 @@ Base.similar(x::TrackedArray, dims::Union{AbstractUnitRange,Integer}...) =
Base.similar(x::TrackedArray, T::Type) = similar(data(x), T)

# TODO decide if keeping both data and value. The problem is TrackedScalar
+value(x) = x
value(x::TrackedArray) = data(x)
value(x::TrackedScalar) = data(x)[]

@@ -69,6 +70,7 @@ Base.:(==)(x::TrackedArray, y::TrackedArray) = value(x) == value(y)
Base.isless(x::TrackedScalar, y) = isless(value(x), y)
Base.isless(x, y::TrackedScalar) = isless(x, value(y))
Base.isless(x::TrackedScalar, y::TrackedScalar) = isless(value(x), value(y))
+Base.isapprox(x::TrackedScalar, y; kws...) = isapprox(x.data[], y; kws...)

Base.show(io::IO, ::Type{TrackedArray{T,N,A}}) where {T,N,A<:AbstractArray{T,N}} =
  print(io, "TrackedArray{…,$A}")
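
Downstream of these tracker additions (a sketch, assuming `mse` of a tracked array yields a `TrackedScalar`): a tracked loss can now be compared directly against plain numbers, which the new optimiser test below relies on:

    l = Flux.mse(param([1.0, 2.0]), [1.0, 2.1])
    l < 0.01            # via isless on the underlying value (l is 0.005)
    isapprox(l, 0.005)  # via the new isapprox method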

test/optimise.jl (new file)
@@ -0,0 +1,17 @@
+using Flux.Optimise
+using Flux.Tracker
+
+@testset "Optimise" begin
+  w = randn(10, 10)
+  for Opt in [SGD, Nesterov, Momentum, ADAM, RMSProp, ps -> ADAGrad(ps, 0.1), ADADelta, AMSGrad]
+    w′ = param(randn(10, 10))
+    loss(x) = Flux.mse(w*x, w′*x)
+    opt = Opt([w′])
+    for t=1:10^5
+      l = loss(rand(10))
+      back!(l)
+      opt()
+    end
+    @test Flux.mse(w, w′) < 0.01
+  end
+end

@@ -5,5 +5,6 @@ using Flux, Base.Test
include("utils.jl")
include("tracker.jl")
include("layers/normalisation.jl")
+include("optimise.jl")
end