From 669273b008f4967d578dd8e692045973a16dbd9b Mon Sep 17 00:00:00 2001
From: Iblis Lin
Date: Tue, 17 Oct 2017 17:26:15 +0800
Subject: [PATCH 1/9] layer: implement BatchNorm layer

See [Batch Normalization: Accelerating Deep Network Training by Reducing
Internal Covariate Shift](https://arxiv.org/pdf/1502.03167.pdf)
---
 src/Flux.jl                 |  2 +-
 src/layers/normalisation.jl | 69 +++++++++++++++++++++++++++++++++++++
 2 files changed, 70 insertions(+), 1 deletion(-)

diff --git a/src/Flux.jl b/src/Flux.jl
index 242c8b1f..ad85efcb 100644
--- a/src/Flux.jl
+++ b/src/Flux.jl
@@ -7,7 +7,7 @@ module Flux
 using Juno, Requires
 using Lazy: @forward
 
-export Chain, Dense, RNN, LSTM, Dropout,
+export BatchNorm, Chain, Dense, RNN, LSTM, Dropout,
   SGD, ADAM, Momentum, Nesterov,
   param, params, mapleaves
 
diff --git a/src/layers/normalisation.jl b/src/layers/normalisation.jl
index 08c21428..10425d96 100644
--- a/src/layers/normalisation.jl
+++ b/src/layers/normalisation.jl
@@ -43,3 +43,72 @@ function (a::Dropout)(x)
 end
 
 _testmode!(a::Dropout, test) = (a.active = !test)
+
+"""
+    BatchNorm(dims...; λ = identity,
+              initβ = zeros, initγ = ones, ϵ = 1e-8, momentum = .1)
+
+Batch Normalization Layer
+
+See [Batch Normalization: Accelerating Deep Network Training by Reducing
+     Internal Covariate Shift](https://arxiv.org/pdf/1502.03167.pdf)
+
+In the MNIST example below, the `BatchNorm` layer is placed
+before the activation function, so that the input of the
+following layer is normalised.
+
+```julia
+julia> m = Chain(
+         Dense(28^2, 64),
+         BatchNorm(64, λ = relu),
+         Dense(64, 10),
+         BatchNorm(10),
+         softmax)
+Chain(Dense(784, 64), BatchNorm(64, λ = NNlib.relu), Dense(64, 10), BatchNorm(10), NNlib.softmax)
+```
+"""
+mutable struct BatchNorm{F,V}
+  λ::F  # activation function
+  β::V  # bias
+  γ::V  # scale
+  μ     # moving mean
+  σ     # moving std
+  ϵ::Float64
+  momentum::Float64
+  active::Bool
+end
+
+BatchNorm(dims::Integer...; λ = identity,
+          initβ = zeros, initγ = ones, ϵ = 1e-8, momentum = .1) =
+  BatchNorm(λ, param(initβ(dims)), param(initγ(dims)), 0., 1., momentum, ϵ, true)
+
+function (BN::BatchNorm)(x)
+  if !BN.active
+    μ = BN.μ
+    σ = BN.σ
+  else
+    m = size(x, 2)  # batch size
+    μ = sum(x, 2) ./ m
+    σ = sqrt.(sum((x .- μ).^2, 2) ./ (m - 1) .+ BN.ϵ)
+
+    # update moving mean/std
+    mtm = BN.momentum
+    BN.μ = mtm .* μ.data .+ (1 - mtm) .* BN.μ
+    BN.σ = mtm .* σ.data .+ (1 - mtm) .* BN.σ
+  end
+
+  BN.λ.(BN.γ .* ((x .- μ) ./ σ) .+ BN.β)
+end
+
+children(BN::BatchNorm) =
+  (BN.λ, BN.β, BN.γ, BN.μ, BN.σ, BN.momentum, BN.ϵ, BN.active)
+mapchildren(f, BN::BatchNorm) =
+  BatchNorm(λ, f(BN.β), f(BN.γ), BN.μ, BN.σ, BN.momentum, BN.ϵ, BN.active)
+
+_testmode!(BN::BatchNorm, test) = (BN.active = !test)
+
+function Base.show(io::IO, l::BatchNorm)
+  print(io, "BatchNorm($(join(size(l.β), ", "))")
+  (l.λ == identity) || print(io, ", λ = $(l.λ)")
+  print(io, ")")
+end
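
A note for reading the forward pass above: in training mode the layer normalises each feature with the statistics of the current batch and folds those statistics into moving estimates for later use in test mode. Below is a standalone sketch with plain arrays — a hypothetical helper, not part of the patch — written with the same Julia 0.6-era reductions (`sum(x, 2)`) as the PR.

```julia
# Standalone sketch of the training branch of (BN::BatchNorm)(x); `x` is a
# features × batch matrix, and γ, β, μ_mov, σ_mov are per-feature.
function batchnorm_train(x, γ, β, μ_mov, σ_mov; ϵ = 1e-8, mtm = 0.1)
  m = size(x, 2)                                   # batch size
  μ = sum(x, 2) ./ m                               # per-feature batch mean
  σ = sqrt.(sum((x .- μ).^2, 2) ./ (m - 1) .+ ϵ)   # batch std (divisor revisited in patch 3)
  # blend the batch statistics into the moving estimates used at test time
  μ_mov = mtm .* μ .+ (1 - mtm) .* μ_mov
  σ_mov = mtm .* σ .+ (1 - mtm) .* σ_mov
  y = γ .* ((x .- μ) ./ σ) .+ β                    # normalise, then scale and shift
  y, μ_mov, σ_mov
end
```
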
From e0201be7704d82eb7d674af753fe3b81beda8ba9 Mon Sep 17 00:00:00 2001
From: Iblis Lin
Date: Mon, 30 Oct 2017 11:34:51 +0800
Subject: [PATCH 2/9] batchnorm: parameterize momentum and epsilon

---
 src/layers/normalisation.jl | 18 +++++++++++-------
 1 file changed, 11 insertions(+), 7 deletions(-)

diff --git a/src/layers/normalisation.jl b/src/layers/normalisation.jl
index 10425d96..a105e1ba 100644
--- a/src/layers/normalisation.jl
+++ b/src/layers/normalisation.jl
@@ -67,32 +67,35 @@ julia> m = Chain(
 Chain(Dense(784, 64), BatchNorm(64, λ = NNlib.relu), Dense(64, 10), BatchNorm(10), NNlib.softmax)
 ```
 """
-mutable struct BatchNorm{F,V}
+mutable struct BatchNorm{F,V,N}
   λ::F  # activation function
   β::V  # bias
   γ::V  # scale
   μ     # moving mean
   σ     # moving std
-  ϵ::Float64
-  momentum::Float64
+  ϵ::N
+  momentum::N
   active::Bool
 end
 
 BatchNorm(dims::Integer...; λ = identity,
           initβ = zeros, initγ = ones, ϵ = 1e-8, momentum = .1) =
-  BatchNorm(λ, param(initβ(dims)), param(initγ(dims)), 0., 1., momentum, ϵ, true)
+  BatchNorm(λ, param(initβ(dims)), param(initγ(dims)), 0., 1., ϵ, momentum, true)
 
 function (BN::BatchNorm)(x)
   if !BN.active
     μ = BN.μ
     σ = BN.σ
   else
+    T = eltype(x)
+
+    ϵ = T(BN.ϵ)
     m = size(x, 2)  # batch size
     μ = sum(x, 2) ./ m
-    σ = sqrt.(sum((x .- μ).^2, 2) ./ (m - 1) .+ BN.ϵ)
+    σ = sqrt.(sum((x .- μ).^2, 2) ./ (m - 1) .+ ϵ)
 
     # update moving mean/std
-    mtm = BN.momentum
+    mtm = T(BN.momentum)
     BN.μ = mtm .* μ.data .+ (1 - mtm) .* BN.μ
     BN.σ = mtm .* σ.data .+ (1 - mtm) .* BN.σ
   end
@@ -102,7 +105,8 @@ end
 
 children(BN::BatchNorm) =
   (BN.λ, BN.β, BN.γ, BN.μ, BN.σ, BN.momentum, BN.ϵ, BN.active)
-mapchildren(f, BN::BatchNorm) =
+
+mapchildren(f, BN::BatchNorm) =  # e.g. mapchildren(cu, BN)
   BatchNorm(λ, f(BN.β), f(BN.γ), BN.μ, BN.σ, BN.momentum, BN.ϵ, BN.active)
 
 _testmode!(BN::BatchNorm, test) = (BN.active = !test)

From b3356cc6bb91bedf610811031504dfabb1d1bc27 Mon Sep 17 00:00:00 2001
From: Iblis Lin
Date: Mon, 30 Oct 2017 12:57:30 +0800
Subject: [PATCH 3/9] batchnorm: batch σ correct coefficient
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 src/layers/normalisation.jl | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/src/layers/normalisation.jl b/src/layers/normalisation.jl
index a105e1ba..f7a425b2 100644
--- a/src/layers/normalisation.jl
+++ b/src/layers/normalisation.jl
@@ -92,12 +92,12 @@ function (BN::BatchNorm)(x)
     ϵ = T(BN.ϵ)
     m = size(x, 2)  # batch size
     μ = sum(x, 2) ./ m
-    σ = sqrt.(sum((x .- μ).^2, 2) ./ (m - 1) .+ ϵ)
+    σ = sqrt.(sum((x .- μ).^2, 2) ./ m .+ ϵ)
 
     # update moving mean/std
     mtm = T(BN.momentum)
-    BN.μ = mtm .* μ.data .+ (1 - mtm) .* BN.μ
-    BN.σ = mtm .* σ.data .+ (1 - mtm) .* BN.σ
+    BN.μ = (1 - mtm) .* BN.μ .+ mtm .* μ.data
+    BN.σ = (1 - mtm) .* BN.σ .+ mtm .* σ.data .* m ./ (m - 1)
   end
 
   BN.λ.(BN.γ .* ((x .- μ) ./ σ) .+ BN.β)
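
A note on the `T = eltype(x)` conversion introduced in patch 2: without it, the `Float64` defaults for `ϵ` and `momentum` would promote `Float32` activations (e.g. on the GPU) to `Float64` inside the broadcast. A quick, self-contained check of that behaviour:

```julia
# A Float64 scalar widens Float32 array arithmetic; converting it to the
# array's eltype first keeps the computation in the input precision.
x32 = rand(Float32, 4, 3)

eltype(x32 .+ 1e-8)            # Float64 — silently widened
eltype(x32 .+ Float32(1e-8))   # Float32 — stays in the input's precision
```
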
From ce468434593ada3abd2332a052fb84db709bc4ac Mon Sep 17 00:00:00 2001
From: Iblis Lin
Date: Mon, 30 Oct 2017 13:24:35 +0800
Subject: [PATCH 4/9] batchnorm: add test cases

---
 test/layers/normalisation.jl | 40 ++++++++++++++++++++++++++++++++++++
 1 file changed, 40 insertions(+)

diff --git a/test/layers/normalisation.jl b/test/layers/normalisation.jl
index 5a302a51..e3115f67 100644
--- a/test/layers/normalisation.jl
+++ b/test/layers/normalisation.jl
@@ -26,3 +26,43 @@ using Flux: testmode!
   y = m(x)
   @test count(a->a == 0, y) == 0
 end
+
+@testset "BatchNorm" begin
+  let m = BatchNorm(2), x = param([1 2; 3 4; 5 6]')
+
+    @test m.β.data == [0, 0]  # initβ(2)
+    @test m.γ.data == [1, 1]  # initγ(2)
+    # initial m.σ is 1
+    # initial m.μ is 0
+    @test m.active
+
+    # @test m(x).data ≈ [-1 -1; 0 0; 1 1]'
+    m(x)
+
+    # julia> x
+    #  2×3 Array{Float64,2}:
+    #  1.0  3.0  5.0
+    #  2.0  4.0  6.0
+    #
+    # μ of batch will be
+    #  (1. + 3. + 5.) / 3 = 3
+    #  (2. + 4. + 6.) / 3 = 4
+    #
+    # ∴ update rule with momentum:
+    #  .1 * 3 + 0 = .3
+    #  .1 * 4 + 0 = .4
+    @test m.μ ≈ reshape([0.3, 0.4], 2, 1)
+
+    # julia> .1 .* std(x, 2, corrected=false) .* (3 / 2) .+ .9 .* [1., 1.]
+    # 2×1 Array{Float64,2}:
+    #  1.14495
+    #  1.14495
+    @test m.σ ≈ .1 .* std(x.data, 2, corrected=false) .* (3 / 2) .+ .9 .* [1., 1.]
+
+    testmode!(m)
+    @test !m.active
+
+    x′ = m(x).data
+    @test x′[1] ≈ (1 - 0.3) / 1.1449489742783179
+  end
+end

From 5253841acce973def6d2fb24e235f1ad63bca8da Mon Sep 17 00:00:00 2001
From: Iblis Lin
Date: Mon, 30 Oct 2017 13:33:01 +0800
Subject: [PATCH 5/9] batchnorm: update docs

---
 docs/src/models/layers.md   | 1 +
 src/layers/normalisation.jl | 8 +++++---
 2 files changed, 6 insertions(+), 3 deletions(-)

diff --git a/docs/src/models/layers.md b/docs/src/models/layers.md
index 5d5d2ee8..98abae3b 100644
--- a/docs/src/models/layers.md
+++ b/docs/src/models/layers.md
@@ -36,5 +36,6 @@ swish
 These layers don't affect the structure of the network but may improve training times or reduce overfitting.
 
 ```@docs
+BatchNorm
 Dropout
 ```

diff --git a/src/layers/normalisation.jl b/src/layers/normalisation.jl
index f7a425b2..1bd44b62 100644
--- a/src/layers/normalisation.jl
+++ b/src/layers/normalisation.jl
@@ -2,8 +2,8 @@
     testmode!(m)
     testmode!(m, false)
 
-Put layers like [`Dropout`](@ref) and `BatchNorm` into testing mode (or back to
-training mode with `false`).
+Put layers like [`Dropout`](@ref) and [`BatchNorm`](@ref) into testing mode
+(or back to training mode with `false`).
 """
 function testmode!(m, val::Bool=true)
   prefor(x -> _testmode!(x, val), m)
@@ -48,7 +48,7 @@ _testmode!(a::Dropout, test) = (a.active = !test)
     BatchNorm(dims...; λ = identity,
              initβ = zeros, initγ = ones, ϵ = 1e-8, momentum = .1)
 
-Batch Normalization Layer
+Batch Normalization Layer for a [`Dense`](@ref) layer.
 
 See [Batch Normalization: Accelerating Deep Network Training by Reducing
      Internal Covariate Shift](https://arxiv.org/pdf/1502.03167.pdf)
@@ -65,6 +65,8 @@ julia> m = Chain(
          BatchNorm(10),
          softmax)
 Chain(Dense(784, 64), BatchNorm(64, λ = NNlib.relu), Dense(64, 10), BatchNorm(10), NNlib.softmax)
+
+julia> opt = SGD(params(m), 10) # a crazy learning rate
 ```
 """
 mutable struct BatchNorm{F,V,N}

From 7f5ba594a958c5623c149d5a2caab64424b63298 Mon Sep 17 00:00:00 2001
From: Iblis Lin
Date: Mon, 30 Oct 2017 13:37:48 +0800
Subject: [PATCH 6/9] batchnorm: more test cases

---
 test/layers/normalisation.jl | 12 ++++++++++++
 1 file changed, 12 insertions(+)

diff --git a/test/layers/normalisation.jl b/test/layers/normalisation.jl
index e3115f67..118a5700 100644
--- a/test/layers/normalisation.jl
+++ b/test/layers/normalisation.jl
@@ -65,4 +65,16 @@ end
     x′ = m(x).data
     @test x′[1] ≈ (1 - 0.3) / 1.1449489742783179
   end
+
+  # with activation function
+  let m = BatchNorm(2, λ = σ), x = param([1 2; 3 4; 5 6]')
+    @test m.active
+    m(x)
+
+    testmode!(m)
+    @test !m.active
+
+    x′ = m(x).data
+    @test x′[1] ≈ σ((1 - 0.3) / 1.1449489742783179)
+  end
 end
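
The constants asserted in these tests can be checked by hand: each feature of the batch `[1 3 5; 2 4 6]` has uncorrected standard deviation `sqrt(8/3)`, and with momentum `.1`, an initial moving mean of 0, an initial moving std of 1, and the `m / (m - 1)` factor from patch 3, the moving statistics and the test-mode output come out as below (plain Julia, no Flux; the variable names are only for illustration):

```julia
σ_batch = sqrt(8 / 3)   # uncorrected std of [1, 3, 5] (and of [2, 4, 6])
m = 3                   # batch size

μ_mov = 0.9 * 0.0 + 0.1 * 3.0                     # 0.3, the tested moving mean (feature 1)
σ_mov = 0.9 * 1.0 + 0.1 * σ_batch * m / (m - 1)   # ≈ 1.1449489742783179

y = (1 - μ_mov) / σ_mov      # ≈ 0.6114, the test-mode output for x[1, 1]
y_act = 1 / (1 + exp(-y))    # ≈ 0.648; this is σ(y), the value patch 6 expects with λ = σ
```
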
From 477da754285f1aab8719a8cbee698db0619cebb1 Mon Sep 17 00:00:00 2001
From: Iblis Lin
Date: Mon, 30 Oct 2017 13:42:00 +0800
Subject: [PATCH 7/9] batchnorm: fix mapchildren

---
 src/layers/normalisation.jl | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/layers/normalisation.jl b/src/layers/normalisation.jl
index 1bd44b62..ee606b40 100644
--- a/src/layers/normalisation.jl
+++ b/src/layers/normalisation.jl
@@ -109,7 +109,7 @@ children(BN::BatchNorm) =
   (BN.λ, BN.β, BN.γ, BN.μ, BN.σ, BN.momentum, BN.ϵ, BN.active)
 
 mapchildren(f, BN::BatchNorm) =  # e.g. mapchildren(cu, BN)
-  BatchNorm(λ, f(BN.β), f(BN.γ), BN.μ, BN.σ, BN.momentum, BN.ϵ, BN.active)
+  BatchNorm(BN.λ, f(BN.β), f(BN.γ), BN.μ, BN.σ, BN.momentum, BN.ϵ, BN.active)
 
 _testmode!(BN::BatchNorm, test) = (BN.active = !test)

From 88bd8a8fbd17139c9d2a5ef01cb575e73f604bf9 Mon Sep 17 00:00:00 2001
From: Iblis Lin
Date: Thu, 2 Nov 2017 13:40:06 +0800
Subject: [PATCH 8/9] batchnorm: make CuArrays happy

---
 src/layers/normalisation.jl | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/src/layers/normalisation.jl b/src/layers/normalisation.jl
index ee606b40..d4a9c94e 100644
--- a/src/layers/normalisation.jl
+++ b/src/layers/normalisation.jl
@@ -66,7 +66,7 @@ julia> m = Chain(
          softmax)
 Chain(Dense(784, 64), BatchNorm(64, λ = NNlib.relu), Dense(64, 10), BatchNorm(10), NNlib.softmax)
 
-julia> opt = SGD(params(m), 10) # a crazy learning rate
+julia> opt = SGD(params(m), 10, decay = .1) # a crazy learning rate
 ```
 """
 mutable struct BatchNorm{F,V,N}
@@ -85,6 +85,8 @@ BatchNorm(dims::Integer...; λ = identity,
   BatchNorm(λ, param(initβ(dims)), param(initγ(dims)), 0., 1., ϵ, momentum, true)
 
 function (BN::BatchNorm)(x)
+  λ, γ, β = BN.λ, BN.γ, BN.β
+
   if !BN.active
     μ = BN.μ
     σ = BN.σ
@@ -102,7 +104,7 @@ function (BN::BatchNorm)(x)
     BN.σ = (1 - mtm) .* BN.σ .+ mtm .* σ.data .* m ./ (m - 1)
   end
 
-  BN.λ.(BN.γ .* ((x .- μ) ./ σ) .+ BN.β)
+  λ.(γ .* ((x .- μ) ./ σ) .+ β)
 end

From 6c7613e02b2ce102d742c516caa715e4bf67538a Mon Sep 17 00:00:00 2001
From: Iblis Lin
Date: Thu, 2 Nov 2017 14:20:34 +0800
Subject: [PATCH 9/9] batchnorm: leverage TrackedArray `mean`

---
 src/layers/normalisation.jl | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/layers/normalisation.jl b/src/layers/normalisation.jl
index d4a9c94e..bd1425d8 100644
--- a/src/layers/normalisation.jl
+++ b/src/layers/normalisation.jl
@@ -95,7 +95,7 @@ function (BN::BatchNorm)(x)
 
     ϵ = T(BN.ϵ)
     m = size(x, 2)  # batch size
-    μ = sum(x, 2) ./ m
+    μ = mean(x, 2)
     σ = sqrt.(sum((x .- μ).^2, 2) ./ m .+ ϵ)
 
     # update moving mean/std
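
Patch 7's one-line fix is worth spelling out: `mapchildren` rebuilds the layer from its children with `f` applied to the trainable arrays (e.g. `mapchildren(cu, BN)` to move the parameters to the GPU), so every field has to be read from the instance being mapped — the earlier bare `λ` referred to an undefined global. A self-contained toy version of the pattern (illustrative type only, not Flux's actual Treelike machinery):

```julia
# children/mapchildren on a toy layer: reconstruct the struct, transforming
# only the array-valued fields and passing everything else through.
struct Affine
  λ   # activation, passed through untouched
  W   # weight matrix, transformed by `f`
  b   # bias vector, transformed by `f`
end

children(a::Affine) = (a.λ, a.W, a.b)
mapchildren(f, a::Affine) = Affine(a.λ, f(a.W), f(a.b))  # note: a.λ, not a bare λ

a = Affine(identity, rand(2, 2), rand(2))
b = mapchildren(x -> 2 .* x, a)   # stand-in for e.g. `cu`
b.W == 2 .* a.W                   # true
```
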