From 669273b008f4967d578dd8e692045973a16dbd9b Mon Sep 17 00:00:00 2001 From: Iblis Lin Date: Tue, 17 Oct 2017 17:26:15 +0800 Subject: [PATCH 01/18] layer: implement BatchNorm layer See [Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift](https://arxiv.org/pdf/1502.03167.pdf) --- src/Flux.jl | 2 +- src/layers/normalisation.jl | 69 +++++++++++++++++++++++++++++++++++++ 2 files changed, 70 insertions(+), 1 deletion(-) diff --git a/src/Flux.jl b/src/Flux.jl index 242c8b1f..ad85efcb 100644 --- a/src/Flux.jl +++ b/src/Flux.jl @@ -7,7 +7,7 @@ module Flux using Juno, Requires using Lazy: @forward -export Chain, Dense, RNN, LSTM, Dropout, +export BatchNorm, Chain, Dense, RNN, LSTM, Dropout, SGD, ADAM, Momentum, Nesterov, param, params, mapleaves diff --git a/src/layers/normalisation.jl b/src/layers/normalisation.jl index 08c21428..10425d96 100644 --- a/src/layers/normalisation.jl +++ b/src/layers/normalisation.jl @@ -43,3 +43,72 @@ function (a::Dropout)(x) end _testmode!(a::Dropout, test) = (a.active = !test) + +""" + BatchNorm(dims...; λ = identity, + initβ = zeros, initγ = ones, ϵ = 1e-8, momentum = .1) + +Batch Normalization Layer + +See [Batch Normalization: Accelerating Deep Network Training by Reducing + Internal Covariate Shift](https://arxiv.org/pdf/1502.03167.pdf) + +In the example of MNIST, +in order to normalize the input of other layer, +put the `BatchNorm` layer before activation function. + +```julia +julia> m = Chain( + Dense(28^2, 64), + BatchNorm(64, λ = relu), + Dense(64, 10), + BatchNorm(10), + softmax) +Chain(Dense(784, 64), BatchNorm(64, λ = NNlib.relu), Dense(64, 10), BatchNorm(10), NNlib.softmax) +``` +""" +mutable struct BatchNorm{F,V} + λ::F # activation function + β::V # bias + γ::V # scale + μ # moving mean + σ # moving std + ϵ::Float64 + momentum::Float64 + active::Bool +end + +BatchNorm(dims::Integer...; λ = identity, + initβ = zeros, initγ = ones, ϵ = 1e-8, momentum = .1) = + BatchNorm(λ, param(initβ(dims)), param(initγ(dims)), 0., 1., momentum, ϵ, true) + +function (BN::BatchNorm)(x) + if !BN.active + μ = BN.μ + σ = BN.σ + else + m = size(x, 2) # batch size + μ = sum(x, 2) ./ m + σ = sqrt.(sum((x .- μ).^2, 2) ./ (m - 1) .+ BN.ϵ) + + # update moving mean/std + mtm = BN.momentum + BN.μ = mtm .* μ.data .+ (1 - mtm) .* BN.μ + BN.σ = mtm .* σ.data .+ (1 - mtm) .* BN.σ + end + + BN.λ.(BN.γ .* ((x .- μ) ./ σ) .+ BN.β) +end + +children(BN::BatchNorm) = + (BN.λ, BN.β, BN.γ, BN.μ, BN.σ, BN.momentum, BN.ϵ, BN.active) +mapchildren(f, BN::BatchNorm) = + BatchNorm(λ, f(BN.β), f(BN.γ), BN.μ, BN.σ, BN.momentum, BN.ϵ, BN.active) + +_testmode!(BN::BatchNorm, test) = (BN.active = !test) + +function Base.show(io::IO, l::BatchNorm) + print(io, "BatchNorm($(join(size(l.β), ", "))") + (l.λ == identity) || print(io, ", λ = $(l.λ)") + print(io, ")") +end From e0201be7704d82eb7d674af753fe3b81beda8ba9 Mon Sep 17 00:00:00 2001 From: Iblis Lin Date: Mon, 30 Oct 2017 11:34:51 +0800 Subject: [PATCH 02/18] batchnorm: parameterize momentum and epsilon --- src/layers/normalisation.jl | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/src/layers/normalisation.jl b/src/layers/normalisation.jl index 10425d96..a105e1ba 100644 --- a/src/layers/normalisation.jl +++ b/src/layers/normalisation.jl @@ -67,32 +67,35 @@ julia> m = Chain( Chain(Dense(784, 64), BatchNorm(64, λ = NNlib.relu), Dense(64, 10), BatchNorm(10), NNlib.softmax) ``` """ -mutable struct BatchNorm{F,V} +mutable struct BatchNorm{F,V,N} λ::F # activation function β::V # bias γ::V # scale μ # moving mean σ # moving std - ϵ::Float64 - momentum::Float64 + ϵ::N + momentum::N active::Bool end BatchNorm(dims::Integer...; λ = identity, initβ = zeros, initγ = ones, ϵ = 1e-8, momentum = .1) = - BatchNorm(λ, param(initβ(dims)), param(initγ(dims)), 0., 1., momentum, ϵ, true) + BatchNorm(λ, param(initβ(dims)), param(initγ(dims)), 0., 1., ϵ, momentum, true) function (BN::BatchNorm)(x) if !BN.active μ = BN.μ σ = BN.σ else + T = eltype(x) + + ϵ = T(BN.ϵ) m = size(x, 2) # batch size μ = sum(x, 2) ./ m - σ = sqrt.(sum((x .- μ).^2, 2) ./ (m - 1) .+ BN.ϵ) + σ = sqrt.(sum((x .- μ).^2, 2) ./ (m - 1) .+ ϵ) # update moving mean/std - mtm = BN.momentum + mtm = T(BN.momentum) BN.μ = mtm .* μ.data .+ (1 - mtm) .* BN.μ BN.σ = mtm .* σ.data .+ (1 - mtm) .* BN.σ end @@ -102,7 +105,8 @@ end children(BN::BatchNorm) = (BN.λ, BN.β, BN.γ, BN.μ, BN.σ, BN.momentum, BN.ϵ, BN.active) -mapchildren(f, BN::BatchNorm) = + +mapchildren(f, BN::BatchNorm) = # e.g. mapchildren(cu, BN) BatchNorm(λ, f(BN.β), f(BN.γ), BN.μ, BN.σ, BN.momentum, BN.ϵ, BN.active) _testmode!(BN::BatchNorm, test) = (BN.active = !test) From b3356cc6bb91bedf610811031504dfabb1d1bc27 Mon Sep 17 00:00:00 2001 From: Iblis Lin Date: Mon, 30 Oct 2017 12:57:30 +0800 Subject: [PATCH 03/18] =?UTF-8?q?batchnorm:=20batch=20=CF=83=20correct=20c?= =?UTF-8?q?oefficient?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/layers/normalisation.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/layers/normalisation.jl b/src/layers/normalisation.jl index a105e1ba..f7a425b2 100644 --- a/src/layers/normalisation.jl +++ b/src/layers/normalisation.jl @@ -92,12 +92,12 @@ function (BN::BatchNorm)(x) ϵ = T(BN.ϵ) m = size(x, 2) # batch size μ = sum(x, 2) ./ m - σ = sqrt.(sum((x .- μ).^2, 2) ./ (m - 1) .+ ϵ) + σ = sqrt.(sum((x .- μ).^2, 2) ./ m .+ ϵ) # update moving mean/std mtm = T(BN.momentum) - BN.μ = mtm .* μ.data .+ (1 - mtm) .* BN.μ - BN.σ = mtm .* σ.data .+ (1 - mtm) .* BN.σ + BN.μ = (1 - mtm) .* BN.μ .+ mtm .* μ.data + BN.σ = (1 - mtm) .* BN.σ .+ mtm .* σ.data .* m ./ (m - 1) end BN.λ.(BN.γ .* ((x .- μ) ./ σ) .+ BN.β) From ce468434593ada3abd2332a052fb84db709bc4ac Mon Sep 17 00:00:00 2001 From: Iblis Lin Date: Mon, 30 Oct 2017 13:24:35 +0800 Subject: [PATCH 04/18] batchnorm: add test cases --- test/layers/normalisation.jl | 40 ++++++++++++++++++++++++++++++++++++ 1 file changed, 40 insertions(+) diff --git a/test/layers/normalisation.jl b/test/layers/normalisation.jl index 5a302a51..e3115f67 100644 --- a/test/layers/normalisation.jl +++ b/test/layers/normalisation.jl @@ -26,3 +26,43 @@ using Flux: testmode! y = m(x) @test count(a->a == 0, y) == 0 end + +@testset "BatchNorm" begin + let m = BatchNorm(2), x = param([1 2; 3 4; 5 6]') + + @test m.β.data == [0, 0] # initβ(2) + @test m.γ.data == [1, 1] # initγ(2) + # initial m.σ is 1 + # initial m.μ is 0 + @test m.active + + # @test m(x).data ≈ [-1 -1; 0 0; 1 1]' + m(x) + + # julia> x + # 2×3 Array{Float64,2}: + # 1.0 3.0 5.0 + # 2.0 4.0 6.0 + # + # μ of batch will be + # (1. + 3. + 5.) / 3 = 3 + # (2. + 4. + 6.) / 3 = 4 + # + # ∴ update rule with momentum: + # .1 * 3 + 0 = .3 + # .1 * 4 + 0 = .4 + @test m.μ ≈ reshape([0.3, 0.4], 2, 1) + + # julia> .1 .* std(x, 2, corrected=false) .* (3 / 2).+ .9 .* [1., 1.] + # 2×1 Array{Float64,2}: + # 1.14495 + # 1.14495 + @test m.σ ≈ .1 .* std(x.data, 2, corrected=false) .* (3 / 2).+ .9 .* [1., 1.] + + testmode!(m) + @test !m.active + + x′ = m(x).data + @test x′[1] ≈ (1 - 0.3) / 1.1449489742783179 + end +end From 5253841acce973def6d2fb24e235f1ad63bca8da Mon Sep 17 00:00:00 2001 From: Iblis Lin Date: Mon, 30 Oct 2017 13:33:01 +0800 Subject: [PATCH 05/18] batchnorm: update docs --- docs/src/models/layers.md | 1 + src/layers/normalisation.jl | 8 +++++--- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/docs/src/models/layers.md b/docs/src/models/layers.md index 5d5d2ee8..98abae3b 100644 --- a/docs/src/models/layers.md +++ b/docs/src/models/layers.md @@ -36,5 +36,6 @@ swish These layers don't affect the structure of the network but may improve training times or reduce overfitting. ```@docs +BatchNorm Dropout ``` diff --git a/src/layers/normalisation.jl b/src/layers/normalisation.jl index f7a425b2..1bd44b62 100644 --- a/src/layers/normalisation.jl +++ b/src/layers/normalisation.jl @@ -2,8 +2,8 @@ testmode!(m) testmode!(m, false) -Put layers like [`Dropout`](@ref) and `BatchNorm` into testing mode (or back to -training mode with `false`). +Put layers like [`Dropout`](@ref) and [`BatchNorm`](@ref) into testing mode +(or back to training mode with `false`). """ function testmode!(m, val::Bool=true) prefor(x -> _testmode!(x, val), m) @@ -48,7 +48,7 @@ _testmode!(a::Dropout, test) = (a.active = !test) BatchNorm(dims...; λ = identity, initβ = zeros, initγ = ones, ϵ = 1e-8, momentum = .1) -Batch Normalization Layer +Batch Normalization Layer for [`Dense`](@ref) layer. See [Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift](https://arxiv.org/pdf/1502.03167.pdf) @@ -65,6 +65,8 @@ julia> m = Chain( BatchNorm(10), softmax) Chain(Dense(784, 64), BatchNorm(64, λ = NNlib.relu), Dense(64, 10), BatchNorm(10), NNlib.softmax) + +julia> opt = SGD(params(m), 10) # a crazy learning rate ``` """ mutable struct BatchNorm{F,V,N} From 7f5ba594a958c5623c149d5a2caab64424b63298 Mon Sep 17 00:00:00 2001 From: Iblis Lin Date: Mon, 30 Oct 2017 13:37:48 +0800 Subject: [PATCH 06/18] batchnorm: more test cases --- test/layers/normalisation.jl | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/test/layers/normalisation.jl b/test/layers/normalisation.jl index e3115f67..118a5700 100644 --- a/test/layers/normalisation.jl +++ b/test/layers/normalisation.jl @@ -65,4 +65,16 @@ end x′ = m(x).data @test x′[1] ≈ (1 - 0.3) / 1.1449489742783179 end + + # with activation function + let m = BatchNorm(2, λ = σ), x = param([1 2; 3 4; 5 6]') + @test m.active + m(x) + + testmode!(m) + @test !m.active + + x′ = m(x).data + @test x′[1] ≈ σ((1 - 0.3) / 1.1449489742783179) + end end From 477da754285f1aab8719a8cbee698db0619cebb1 Mon Sep 17 00:00:00 2001 From: Iblis Lin Date: Mon, 30 Oct 2017 13:42:00 +0800 Subject: [PATCH 07/18] batchnorm: fix mapchildren --- src/layers/normalisation.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/layers/normalisation.jl b/src/layers/normalisation.jl index 1bd44b62..ee606b40 100644 --- a/src/layers/normalisation.jl +++ b/src/layers/normalisation.jl @@ -109,7 +109,7 @@ children(BN::BatchNorm) = (BN.λ, BN.β, BN.γ, BN.μ, BN.σ, BN.momentum, BN.ϵ, BN.active) mapchildren(f, BN::BatchNorm) = # e.g. mapchildren(cu, BN) - BatchNorm(λ, f(BN.β), f(BN.γ), BN.μ, BN.σ, BN.momentum, BN.ϵ, BN.active) + BatchNorm(BN.λ, f(BN.β), f(BN.γ), BN.μ, BN.σ, BN.momentum, BN.ϵ, BN.active) _testmode!(BN::BatchNorm, test) = (BN.active = !test) From 88bd8a8fbd17139c9d2a5ef01cb575e73f604bf9 Mon Sep 17 00:00:00 2001 From: Iblis Lin Date: Thu, 2 Nov 2017 13:40:06 +0800 Subject: [PATCH 08/18] batchnorm: make CuArrays happy --- src/layers/normalisation.jl | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/layers/normalisation.jl b/src/layers/normalisation.jl index ee606b40..d4a9c94e 100644 --- a/src/layers/normalisation.jl +++ b/src/layers/normalisation.jl @@ -66,7 +66,7 @@ julia> m = Chain( softmax) Chain(Dense(784, 64), BatchNorm(64, λ = NNlib.relu), Dense(64, 10), BatchNorm(10), NNlib.softmax) -julia> opt = SGD(params(m), 10) # a crazy learning rate +julia> opt = SGD(params(m), 10, decay = .1) # a crazy learning rate ``` """ mutable struct BatchNorm{F,V,N} @@ -85,6 +85,8 @@ BatchNorm(dims::Integer...; λ = identity, BatchNorm(λ, param(initβ(dims)), param(initγ(dims)), 0., 1., ϵ, momentum, true) function (BN::BatchNorm)(x) + λ, γ, β = BN.λ, BN.γ, BN.β + if !BN.active μ = BN.μ σ = BN.σ @@ -102,7 +104,7 @@ function (BN::BatchNorm)(x) BN.σ = (1 - mtm) .* BN.σ .+ mtm .* σ.data .* m ./ (m - 1) end - BN.λ.(BN.γ .* ((x .- μ) ./ σ) .+ BN.β) + λ.(γ .* ((x .- μ) ./ σ) .+ β) end children(BN::BatchNorm) = From 6c7613e02b2ce102d742c516caa715e4bf67538a Mon Sep 17 00:00:00 2001 From: Iblis Lin Date: Thu, 2 Nov 2017 14:20:34 +0800 Subject: [PATCH 09/18] batchnorm: leverage TrackedArray `mean` --- src/layers/normalisation.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/layers/normalisation.jl b/src/layers/normalisation.jl index d4a9c94e..bd1425d8 100644 --- a/src/layers/normalisation.jl +++ b/src/layers/normalisation.jl @@ -95,7 +95,7 @@ function (BN::BatchNorm)(x) ϵ = T(BN.ϵ) m = size(x, 2) # batch size - μ = sum(x, 2) ./ m + μ = mean(x, 2) σ = sqrt.(sum((x .- μ).^2, 2) ./ m .+ ϵ) # update moving mean/std From 13b934c2500b8e39ac24c834079b562057dede5a Mon Sep 17 00:00:00 2001 From: CarloLucibello Date: Thu, 12 Oct 2017 10:31:38 +0200 Subject: [PATCH 10/18] improve optimizers --- src/data/cmudict.jl | 3 +- src/optimise/interface.jl | 50 +++++++++++----------- src/optimise/optimisers.jl | 85 +++++++++++++++++++++----------------- src/tracker/Tracker.jl | 2 + test/optimise.jl | 19 +++++++++ test/runtests.jl | 1 + 6 files changed, 98 insertions(+), 62 deletions(-) create mode 100644 test/optimise.jl diff --git a/src/data/cmudict.jl b/src/data/cmudict.jl index 88b9c6c0..a23c6a3d 100644 --- a/src/data/cmudict.jl +++ b/src/data/cmudict.jl @@ -33,7 +33,8 @@ function rawdict() filter(!isempty, split.(split(readstring(deps("CMUDict", "cmudict")), "\n")))) end -validword(s) = ismatch(r"^[\w-\.]+$", s) +# validword(s) = ismatch(r"^[\w-\.]+$", s) +validword(s) = ismatch(r"^\[\w-\.\]+$", s) cmudict() = filter((s, ps) -> validword(s), rawdict()) diff --git a/src/optimise/interface.jl b/src/optimise/interface.jl index 0b2a25ae..47b0f62c 100644 --- a/src/optimise/interface.jl +++ b/src/optimise/interface.jl @@ -1,5 +1,7 @@ call(f, xs...) = f(xs...) +# note for optimisers: set to zero +# p.Δ at the end of the weigths update function optimiser(ps, fs...) ps = [Param(p) for p in ps] fs = map(ps) do p @@ -10,64 +12,64 @@ function optimiser(ps, fs...) end """ - SGD(params, η = 1; decay = 0) + SGD(params, η = 0.1; decay = 0) -Classic gradient descent optimiser. For each parameter `p` and its -gradient `δp`, this runs `p -= η*δp`. +Classic gradient descent optimiser with learning rate `η`. +For each parameter `p` and its gradient `δp`, this runs `p -= η*δp`. -Supports decayed learning rate decay if the `decay` argument is provided. +Supports inverse decaying learning rate if the `decay` argument is provided. """ -SGD(ps, η = 1; decay = 0) = - optimiser(ps, p -> invdecay(p, decay), p -> descent(p, η)) +SGD(ps, η = 0.1; decay = 0) = + optimiser(ps, p -> invdecay(p, decay), p -> descent(p,η)) """ - Momentum(params, ρ, decay = 0) + Momentum(params, η = 0.01; ρ = 0.9, decay = 0) -SGD with momentum `ρ` and optional learning rate decay. +SGD with learning rate `η`, momentum `ρ` and optional learning rate inverse decay. """ -Momentum(ps, ρ; decay = 0) = - optimiser(ps, p -> momentum(p, ρ), p -> invdecay(p, decay), p -> descent(p, 1)) +Momentum(ps, η = 0.01; ρ = 0.9, decay = 0) = + optimiser(ps, p->invdecay(p,decay), p->momentum(p, ρ, η), p->descent(p,1)) """ - Nesterov(params, ρ, decay = 0) + Nesterov(params, η = 0.01; ρ = 0.9, decay = 0) -SGD with Nesterov momentum `ρ` and optional learning rate decay. +SGD with learning rate `η`, Nesterov momentum `ρ` and optional learning rate inverse decay. """ -Nesterov(ps, ρ; decay = 0) = - optimiser(ps, p -> nesterov(p, ρ), p -> invdecay(p, decay), p -> descent(p, 1)) +Nesterov(ps, η = 0.01; ρ = 0.9, decay = 0) = + optimiser(ps, p->invdecay(p,decay), p->nesterov(p, ρ, η), p->descent(p,1)) """ - RMSProp(params; η = 0.001, ρ = 0.9, ϵ = 1e-8, decay = 0) + RMSProp(params, η = 0.001; ρ = 0.9, ϵ = 1e-8, decay = 0) [RMSProp](http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf) optimiser. Parameters other than learning rate don't need tuning. Often a good choice for recurrent networks. """ RMSProp(ps, η = 0.001; ρ = 0.9, ϵ = 1e-8, decay = 0) = - optimiser(ps, p -> rmsprop(p; η = η, ρ = ρ, ϵ = ϵ), p -> invdecay(p, decay), p -> descent(p, 1)) + optimiser(ps, p->rmsprop(p; η=η, ρ=ρ, ϵ=ϵ), p->invdecay(p,decay), p->descent(p,1)) """ - ADAM(params; η = 0.001, β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0) + ADAM(params, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0) [ADAM](https://arxiv.org/abs/1412.6980v8) optimiser. """ ADAM(ps, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0) = - optimiser(ps, p -> adam(p; η = η, β1 = β1, β2 = β2, ϵ = ϵ), p -> invdecay(p, decay), p -> descent(p, 1)) + optimiser(ps, p->adam(p; η=η, β1=β1, β2=β2, ϵ=ϵ), p->invdecay(p,decay), p->descent(p,1)) """ - ADAGrad(params; η = 0.01, ϵ = 1e-8, decay = 0) + ADAGrad(params, η = 0.01; ϵ = 1e-8, decay = 0) [ADAGrad](http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf) optimiser. Parameters don't need tuning. """ -ADAGrad(ps; η = 0.01, ϵ = 1e-8, decay = 0) = - optimiser(ps, p -> adagrad(p; η = η, ϵ = ϵ), p -> invdecay(p, decay), p -> descent(p, 1)) +ADAGrad(ps, η = 0.01; ϵ = 1e-8, decay = 0) = + optimiser(ps, p->adagrad(p; η=η, ϵ=ϵ), p->invdecay(p,decay), p->descent(p,1)) """ - ADADelta(params; η = 0.01, ρ = 0.95, ϵ = 1e-8, decay = 0) + ADADelta(params; ρ = 0.9, ϵ = 1e-8, decay = 0) [ADADelta](http://arxiv.org/abs/1212.5701) optimiser. Parameters don't need tuning. """ -ADADelta(ps; η = 0.01, ρ = 0.95, ϵ = 1e-8, decay = 0) = - optimiser(ps, p -> adadelta(p; ρ = ρ, ϵ = ϵ), p -> invdecay(p, decay), p -> descent(p, 1)) +ADADelta(ps; ρ = 0.9, ϵ = 1e-8, decay = 0) = + optimiser(ps, p->adadelta(p; ρ=ρ, ϵ=ϵ), p->descent(p,1)) diff --git a/src/optimise/optimisers.jl b/src/optimise/optimisers.jl index abc54090..7cf271b6 100644 --- a/src/optimise/optimisers.jl +++ b/src/optimise/optimisers.jl @@ -1,74 +1,85 @@ function descent(p::Param, η::Real) function () - p.x .-= p.Δ .* η - p.Δ .= 0 + @. p.x -= η * p.Δ + @. p.Δ = 0 end end -function momentum(p::Param, ρ::Real) - mo = zeros(p.x) - () -> p.Δ .= mo .= ρ .* mo .+ p.Δ -end - -function nesterov(p::Param, ρ::Real) - mo = zeros(p.x) +function momentum(p::Param, ρ, η) + v = zeros(p.x) function () - mo .= ρ .* mo .+ p.Δ - p.Δ .= ρ .* mo .+ p.Δ + @. v = ρ * v - η * p.Δ + @. p.Δ = -v end end -function clip(p::Param, thresh::Real) - () -> clamp!(p.Δ, -thresh, thresh) -end - -function weightdecay(p::Param, γ::Real) - () -> p.Δ .+= γ .* p.x -end - -function invdecay(p::Param, γ::Real) - n = 0 +# Ref. https://arxiv.org/pdf/1212.0901.pdf +function nesterov(p::Param, ρ, η) + v = zeros(p.x) function () - p.Δ .*= 1 / (1 + γ * n) - n += 1 + d = @. ρ^2 * v - (1+ρ) * η * p.Δ + @. v = ρ*v - η*p.Δ + @. p.Δ = -d end end function rmsprop(p::Param; η::Real = 0.001, ρ::Real = 0.9, ϵ::Real = 1e-8) - acc = zeros(p.x) .+ ϵ + acc = zeros(p.x) function () - @. acc = ρ * acc + (1 - ρ) * p.Δ ^ 2 - @. p.Δ *= η / √acc + @. acc = ρ * acc + (1 - ρ) * p.Δ^2 + @. p.Δ *= η / (√acc + ϵ) end end function adagrad(p::Param; η::Real = 0.01, ϵ::Real = 1e-8) acc = zeros(p.x) .+ ϵ function () - @. acc += p.Δ ^ 2 + @. acc += p.Δ^2 @. p.Δ *= η / √acc end end -function adadelta(p::Param; ρ::Real = 0.95, ϵ::Real = 1e-8) - acc = zeros(p.x) .+ ϵ - Δacc = zeros(p.x) .+ ϵ +function adadelta(p::Param; ρ::Real = 0.9, ϵ::Real = 1e-8) + acc = zeros(p.x) + Δacc = zeros(p.x) function () - @. acc = ρ * acc + (1 - ρ) * p.Δ ^ 2 - @. p.Δ *= √Δacc / √acc - @. Δacc = ρ * Δacc + (1 - ρ) * p.Δ ^ 2 - end + @. acc = ρ * acc + (1 - ρ) * p.Δ^2 + @. p.Δ *= √(Δacc + ϵ) / √(acc + ϵ) + @. Δacc = ρ * Δacc + (1 - ρ) * p.Δ^2 + end end function adam(p::Param; η::Real = 0.001, β1::Real = 0.9, β2::Real = 0.999, ϵ::Real = 1e-8) mt = zeros(p.x) - vt = zeros(p.x) .+ ϵ + vt = zeros(p.x) β1p, β2p = β1, β2 function () @. mt = β1 * mt + (1 - β1) * p.Δ - @. vt = β2 * vt + (1 - β2) * p.Δ ^ 2 - @. p.Δ = √(1 - β2p) / √(1 - β1p) * mt / √vt * η + @. vt = β2 * vt + (1 - β2) * p.Δ^2 + @. p.Δ = mt / (1 - β1p) / (sqrt(vt / (1 - β2p)) + ϵ) * η β1p *= β1 β2p *= β2 end end + +clip(p::Param, thresh::Real) = () -> clamp!(p.Δ, -thresh, thresh) + +function expdecay(p::Param, γ::Real) + if γ != 0 + return () -> p.Δ .+= γ .* p.x + else + return () -> nothing + end +end + +function invdecay(p::Param, γ::Real) + if γ != 0 + n = 0 + return () -> begin + p.Δ .*= 1 / (1 + γ * n) + n += 1 + end + else + return () -> nothing + end +end \ No newline at end of file diff --git a/src/tracker/Tracker.jl b/src/tracker/Tracker.jl index 3a64fcb7..57bdc447 100644 --- a/src/tracker/Tracker.jl +++ b/src/tracker/Tracker.jl @@ -58,6 +58,7 @@ Base.similar(x::TrackedArray, dims::Union{AbstractUnitRange,Integer}...) = Base.similar(x::TrackedArray, T::Type) = similar(data(x), T) +# TODO decide if keeping both data and value. The problem is TrackedScalar value(x) = x value(x::TrackedArray) = data(x) value(x::TrackedScalar) = data(x)[] @@ -69,6 +70,7 @@ Base.:(==)(x::TrackedArray, y::TrackedArray) = value(x) == value(x) Base.isless(x::TrackedScalar, y) = isless(value(x), y) Base.isless(x, y::TrackedScalar) = isless(x, value(y)) Base.isless(x::TrackedScalar, y::TrackedScalar) = isless(value(x), value(y)) +Base.isapprox(x::TrackedScalar, y; kws...) = isapprox(x.data[], y; kws...) Base.show(io::IO, ::Type{TrackedArray{T,N,A}}) where {T,N,A<:AbstractArray{T,N}} = print(io, "TrackedArray{…,$A}") diff --git a/test/optimise.jl b/test/optimise.jl new file mode 100644 index 00000000..85fd53f9 --- /dev/null +++ b/test/optimise.jl @@ -0,0 +1,19 @@ +using Flux.Optimise +using Flux.Tracker + +@testset "Optimise" begin + loss(x) = sum(x.^2) + η = 0.1 + # RMSProp gets stuck + for OPT in [SGD, Nesterov, Momentum, ADAM, ADAGrad, ADADelta] + x = param(randn(10)) + opt = OPT == ADADelta ? OPT([x]) : OPT([x], η) + for t=1:10000 + l = loss(x) + back!(l) + opt() + l.data[] < 1e-10 && break + end + @test loss(x) ≈ 0. atol=1e-7 + end +end diff --git a/test/runtests.jl b/test/runtests.jl index efd1a462..bdd1f2d0 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -5,5 +5,6 @@ using Flux, Base.Test include("utils.jl") include("tracker.jl") include("layers/normalisation.jl") +include("optimise.jl") end From 36001d085a3f9175eaee572e8b8532410a8ebf50 Mon Sep 17 00:00:00 2001 From: baggepinnen Date: Mon, 4 Dec 2017 09:17:05 +0100 Subject: [PATCH 11/18] Implement AMSGrad optimiser --- src/optimise/Optimise.jl | 2 +- src/optimise/interface.jl | 9 +++++++++ src/optimise/optimisers.jl | 14 +++++++++++++- 3 files changed, 23 insertions(+), 2 deletions(-) diff --git a/src/optimise/Optimise.jl b/src/optimise/Optimise.jl index 5f144b65..acec542e 100644 --- a/src/optimise/Optimise.jl +++ b/src/optimise/Optimise.jl @@ -1,7 +1,7 @@ module Optimise export update!, params, train!, - SGD, ADAM, Momentum, Nesterov, RMSProp, ADAGrad, ADADelta + SGD, ADAM, Momentum, Nesterov, RMSProp, ADAGrad, ADADelta, AMSGrad struct Param{T} x::T diff --git a/src/optimise/interface.jl b/src/optimise/interface.jl index 0b2a25ae..c6f98553 100644 --- a/src/optimise/interface.jl +++ b/src/optimise/interface.jl @@ -71,3 +71,12 @@ tuning. """ ADADelta(ps; η = 0.01, ρ = 0.95, ϵ = 1e-8, decay = 0) = optimiser(ps, p -> adadelta(p; ρ = ρ, ϵ = ϵ), p -> invdecay(p, decay), p -> descent(p, 1)) + + """ + AMSGrad(params; η = 0.001, β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0) + + [AMSGrad](https://openreview.net/forum?id=ryQu7f-RZ) optimiser. Parameters don't need + tuning. + """ + AMSGrad(params; η = 0.001, β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0) = + optimiser(ps, p -> amsgrad(p; η = η, β1 = β1, β2 = β2, ϵ = ϵ), p -> invdecay(p, decay), p -> descent(p, 1)) diff --git a/src/optimise/optimisers.jl b/src/optimise/optimisers.jl index abc54090..12a14df4 100644 --- a/src/optimise/optimisers.jl +++ b/src/optimise/optimisers.jl @@ -67,8 +67,20 @@ function adam(p::Param; η::Real = 0.001, β1::Real = 0.9, β2::Real = 0.999, ϵ function () @. mt = β1 * mt + (1 - β1) * p.Δ @. vt = β2 * vt + (1 - β2) * p.Δ ^ 2 - @. p.Δ = √(1 - β2p) / √(1 - β1p) * mt / √vt * η + @. p.Δ = √(1 - β2p) / (1 - β1p) * mt / √vt * η β1p *= β1 β2p *= β2 end end + +function amsgrad(p::Param; η::Real = 0.001, β1::Real = 0.9, β2::Real = 0.999, ϵ::Real = 1e-8) + mt = zeros(p.x) + vt = zeros(p.x) .+ ϵ + v̂t = zeros(p.x) .+ ϵ + function () + @. mt = β1 * mt + (1 - β1) * p.Δ + @. vt = β2 * vt + (1 - β2) * p.Δ ^ 2 + @. v̂t = max.(v̂t, vt) + @. p.Δ = η * mt / √v̂t + end +end From 41febee9c171e610336afd79e1d1480100f29a53 Mon Sep 17 00:00:00 2001 From: baggepinnen Date: Mon, 4 Dec 2017 09:34:27 +0100 Subject: [PATCH 12/18] Export and indent --- src/Flux.jl | 2 +- src/optimise/interface.jl | 16 ++++++++-------- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/src/Flux.jl b/src/Flux.jl index df4b1636..2ae8879f 100644 --- a/src/Flux.jl +++ b/src/Flux.jl @@ -8,7 +8,7 @@ using Juno, Requires using Lazy: @forward export Chain, Dense, RNN, LSTM, Dropout, LayerNorm, - SGD, ADAM, Momentum, Nesterov, + SGD, ADAM, Momentum, Nesterov, AMSGrad, param, params, mapleaves using NNlib diff --git a/src/optimise/interface.jl b/src/optimise/interface.jl index c6f98553..679134fe 100644 --- a/src/optimise/interface.jl +++ b/src/optimise/interface.jl @@ -47,7 +47,7 @@ RMSProp(ps, η = 0.001; ρ = 0.9, ϵ = 1e-8, decay = 0) = optimiser(ps, p -> rmsprop(p; η = η, ρ = ρ, ϵ = ϵ), p -> invdecay(p, decay), p -> descent(p, 1)) """ - ADAM(params; η = 0.001, β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0) + ADAM(params, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0) [ADAM](https://arxiv.org/abs/1412.6980v8) optimiser. """ @@ -72,11 +72,11 @@ tuning. ADADelta(ps; η = 0.01, ρ = 0.95, ϵ = 1e-8, decay = 0) = optimiser(ps, p -> adadelta(p; ρ = ρ, ϵ = ϵ), p -> invdecay(p, decay), p -> descent(p, 1)) - """ - AMSGrad(params; η = 0.001, β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0) +""" + AMSGrad(params; η = 0.001, β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0) - [AMSGrad](https://openreview.net/forum?id=ryQu7f-RZ) optimiser. Parameters don't need - tuning. - """ - AMSGrad(params; η = 0.001, β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0) = - optimiser(ps, p -> amsgrad(p; η = η, β1 = β1, β2 = β2, ϵ = ϵ), p -> invdecay(p, decay), p -> descent(p, 1)) +[AMSGrad](https://openreview.net/forum?id=ryQu7f-RZ) optimiser. Parameters don't need +tuning. +""" +AMSGrad(ps, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0) = + optimiser(ps, p -> amsgrad(p; η = η, β1 = β1, β2 = β2, ϵ = ϵ), p -> invdecay(p, decay), p -> descent(p, 1)) From 951c21366a54ab60899f2e9955c05bd8ebaedf5b Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Fri, 8 Dec 2017 16:42:30 +0000 Subject: [PATCH 13/18] fix regex --- src/data/cmudict.jl | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/data/cmudict.jl b/src/data/cmudict.jl index a23c6a3d..9ec567b4 100644 --- a/src/data/cmudict.jl +++ b/src/data/cmudict.jl @@ -33,8 +33,7 @@ function rawdict() filter(!isempty, split.(split(readstring(deps("CMUDict", "cmudict")), "\n")))) end -# validword(s) = ismatch(r"^[\w-\.]+$", s) -validword(s) = ismatch(r"^\[\w-\.\]+$", s) +validword(s) = ismatch(r"^[\w\-\.]+$", s) cmudict() = filter((s, ps) -> validword(s), rawdict()) From 69cc5642b48b685bbbf109af310384f8eae917e4 Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Fri, 8 Dec 2017 17:10:29 +0000 Subject: [PATCH 14/18] regression testing --- test/optimise.jl | 24 +++++++++++------------- 1 file changed, 11 insertions(+), 13 deletions(-) diff --git a/test/optimise.jl b/test/optimise.jl index 85fd53f9..65bb65be 100644 --- a/test/optimise.jl +++ b/test/optimise.jl @@ -2,18 +2,16 @@ using Flux.Optimise using Flux.Tracker @testset "Optimise" begin - loss(x) = sum(x.^2) - η = 0.1 - # RMSProp gets stuck - for OPT in [SGD, Nesterov, Momentum, ADAM, ADAGrad, ADADelta] - x = param(randn(10)) - opt = OPT == ADADelta ? OPT([x]) : OPT([x], η) - for t=1:10000 - l = loss(x) - back!(l) - opt() - l.data[] < 1e-10 && break - end - @test loss(x) ≈ 0. atol=1e-7 + w = randn(10, 10) + for Opt in [SGD, Nesterov, Momentum, ADAM, RMSProp, ps -> ADAGrad(ps, 0.1), ADADelta] + w′ = param(randn(10, 10)) + loss(x) = Flux.mse(w*x, w′*x) + opt = Opt([w′]) + for t=1:10^5 + l = loss(rand(10)) + back!(l) + opt() end + @test Flux.mse(w, w′) < 0.01 + end end From 55bbe50f32d7dfe58360da9da3832add38a8cc38 Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Fri, 8 Dec 2017 18:24:07 +0000 Subject: [PATCH 15/18] regression test --- test/optimise.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/optimise.jl b/test/optimise.jl index 65bb65be..526f0534 100644 --- a/test/optimise.jl +++ b/test/optimise.jl @@ -3,7 +3,7 @@ using Flux.Tracker @testset "Optimise" begin w = randn(10, 10) - for Opt in [SGD, Nesterov, Momentum, ADAM, RMSProp, ps -> ADAGrad(ps, 0.1), ADADelta] + for Opt in [SGD, Nesterov, Momentum, ADAM, RMSProp, ps -> ADAGrad(ps, 0.1), ADADelta, AMSGrad] w′ = param(randn(10, 10)) loss(x) = Flux.mse(w*x, w′*x) opt = Opt([w′]) From 86097e76fdaa149b2caf815d5404c77d16c4f754 Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Fri, 8 Dec 2017 19:34:34 +0000 Subject: [PATCH 16/18] tweak batchnorm example --- src/layers/normalisation.jl | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/layers/normalisation.jl b/src/layers/normalisation.jl index 4eaa6d5b..a018a073 100644 --- a/src/layers/normalisation.jl +++ b/src/layers/normalisation.jl @@ -81,13 +81,12 @@ in order to normalize the input of other layer, put the `BatchNorm` layer before activation function. ```julia -julia> m = Chain( +m = Chain( Dense(28^2, 64), BatchNorm(64, λ = relu), Dense(64, 10), BatchNorm(10), softmax) -Chain(Dense(784, 64), BatchNorm(64, λ = NNlib.relu), Dense(64, 10), BatchNorm(10), NNlib.softmax) ``` """ mutable struct BatchNorm{F,V,N} From b7b6c975bc91c3a9c531178ae7715f216698cdd4 Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Tue, 12 Dec 2017 17:07:39 +0000 Subject: [PATCH 17/18] fixes #110 --- src/tracker/lib.jl | 9 +++++++-- test/tracker.jl | 1 + 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/src/tracker/lib.jl b/src/tracker/lib.jl index 5065a40d..ab250e39 100644 --- a/src/tracker/lib.jl +++ b/src/tracker/lib.jl @@ -70,7 +70,7 @@ back(::typeof(mean), Δ, xs::TrackedArray, region) = # BLAS -for f in :[*, Ac_mul_B].args +for f in :[*, Ac_mul_B, A_mul_Bc].args @eval begin import Base.$f $f(a::TrackedMatrix, b::TrackedMatrix) = TrackedArray(Call($f, a, b)) @@ -94,7 +94,12 @@ end function back(::typeof(Ac_mul_B), Δ, a::AbstractVecOrMat{<:Real}, b::AbstractVecOrMat{<:Real}) @back(a, A_mul_Bt(Δ, data(b))') - @back(b, *(data(a), Δ)) + @back(b, data(a)*Δ) +end + +function back(::typeof(A_mul_Bc), Δ, a::AbstractVecOrMat{<:Real}, b::AbstractVecOrMat{<:Real}) + @back(a, Δ * data(b)) + @back(b, At_mul_B(data(a), Δ)') end # Fast path for matrix-vector diff --git a/test/tracker.jl b/test/tracker.jl index 81a72566..7d9ef4f5 100644 --- a/test/tracker.jl +++ b/test/tracker.jl @@ -10,6 +10,7 @@ gradtest(f, dims...) = gradtest(f, rand.(dims)...) @test gradtest((x, W, b) -> σ.(W*x .+ b), (5,3), (2,5), 2) @test gradtest((w, x) -> w'*x, randn(10, 2), randn(10)) +@test gradtest((w, x) -> w*x', randn(5,5), randn(5,5)) @test gradtest(x -> sin.(sum(x, (2, 3))), (3,4,5)) From 29787eba452a0e12e7c152fe7ded67393f18a8b7 Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Tue, 12 Dec 2017 17:23:15 +0000 Subject: [PATCH 18/18] fixes #114 --- src/tracker/lib.jl | 9 +++++++++ test/tracker.jl | 2 ++ 2 files changed, 11 insertions(+) diff --git a/src/tracker/lib.jl b/src/tracker/lib.jl index ab250e39..f3221bd8 100644 --- a/src/tracker/lib.jl +++ b/src/tracker/lib.jl @@ -58,6 +58,15 @@ Base.findfirst(xs::TrackedArray, args...) = findfirst(xs.data, args...) Base.mean(xs::TrackedArray) = TrackedArray(Call(mean, xs), toarray(xs.data, mean(xs.data))) Base.mean(xs::TrackedArray, region) = TrackedArray(Call(mean, xs, region)) +LinAlg.dot(xs::TrackedVector, ys::TrackedVector) = TrackedArray(Call(dot, xs, ys), toarray(xs.data, dot(data(xs), data(ys)))) +LinAlg.dot(xs::AbstractVector, ys::TrackedVector) = TrackedArray(Call(dot, xs, ys), toarray(xs.data, dot(data(xs), data(ys)))) +LinAlg.dot(xs::TrackedVector, ys::AbstractVector) = TrackedArray(Call(dot, xs, ys), toarray(xs.data, dot(data(xs), data(ys)))) + +function back(::typeof(dot), Δ, xs, ys) + @back(xs, Δ.*ys) + @back(ys, Δ.*xs) +end + # Hacks to get std working Base.std(x::TrackedArray; mean = Base.mean(x)) = sqrt.(sum((x .- mean).^2) ./ (length(x)-1)) diff --git a/test/tracker.jl b/test/tracker.jl index 7d9ef4f5..ac031915 100644 --- a/test/tracker.jl +++ b/test/tracker.jl @@ -38,6 +38,8 @@ end @test gradtest(x -> std(x), rand(5,5)) @test gradtest(x -> std(x, 1), rand(5,5)) +@test gradtest((x, y) -> x .* y, rand(5), rand(5)) + @test gradtest(rand(5)) do x y = x.^2 2y + x