diff --git a/docs/src/models/layers.md b/docs/src/models/layers.md
index 1fd87d41..d92388e1 100644
--- a/docs/src/models/layers.md
+++ b/docs/src/models/layers.md
@@ -37,6 +37,7 @@ These layers don't affect the structure of the network but may improve training
 
 ```@docs
 Flux.testmode!
+BatchNorm
 Dropout
 LayerNorm
 ```
diff --git a/src/Flux.jl b/src/Flux.jl
index f33351b7..526d6bb8 100644
--- a/src/Flux.jl
+++ b/src/Flux.jl
@@ -7,7 +7,8 @@ module Flux
 using Juno, Requires
 using Lazy: @forward
 
-export Chain, Dense, RNN, LSTM, Dropout, LayerNorm,
+export Chain, Dense, RNN, LSTM,
+  Dropout, LayerNorm, BatchNorm,
   SGD, ADAM, Momentum, Nesterov, AMSGrad,
   param, params, mapleaves
 
diff --git a/src/layers/normalisation.jl b/src/layers/normalisation.jl
index d296b0a3..4eaa6d5b 100644
--- a/src/layers/normalisation.jl
+++ b/src/layers/normalisation.jl
@@ -2,8 +2,8 @@
     testmode!(m)
     testmode!(m, false)
 
-Put layers like [`Dropout`](@ref) and `BatchNorm` into testing mode (or back to
-training mode with `false`).
+Put layers like [`Dropout`](@ref) and [`BatchNorm`](@ref) into testing mode
+(or back to training mode with `false`).
 """
 function testmode!(m, val::Bool=true)
   prefor(x -> _testmode!(x, val), m)
@@ -45,6 +45,7 @@ end
 _testmode!(a::Dropout, test) = (a.active = !test)
 
 """
+    LayerNorm(h::Integer)
 
 A [normalisation layer](https://arxiv.org/pdf/1607.06450.pdf) designed to be
@@ -65,3 +66,78 @@ treelike(LayerNorm)
 function Base.show(io::IO, l::LayerNorm)
   print(io, "LayerNorm(", length(l.diag.α), ")")
 end
+
+"""
+    BatchNorm(dims...; λ = identity,
+              initβ = zeros, initγ = ones, ϵ = 1e-8, momentum = .1)
+
+Batch Normalization layer for a [`Dense`](@ref) layer.
+
+See [Batch Normalization: Accelerating Deep Network Training by Reducing
+  Internal Covariate Shift](https://arxiv.org/pdf/1502.03167.pdf).
+
+In the MNIST example below, `BatchNorm` is placed between each `Dense` layer
+and its activation function (passed as `λ`), so that the input of the
+following layer is normalised before the non-linearity is applied.
+
+```julia
+julia> m = Chain(
+         Dense(28^2, 64),
+         BatchNorm(64, λ = relu),
+         Dense(64, 10),
+         BatchNorm(10),
+         softmax)
+Chain(Dense(784, 64), BatchNorm(64, λ = NNlib.relu), Dense(64, 10), BatchNorm(10), NNlib.softmax)
+```
+"""
+mutable struct BatchNorm{F,V,N}
+  λ::F  # activation function
+  β::V  # bias
+  γ::V  # scale
+  μ     # moving mean
+  σ     # moving std
+  ϵ::N
+  momentum::N
+  active::Bool
+end
+
+BatchNorm(dims::Integer...; λ = identity,
+          initβ = zeros, initγ = ones, ϵ = 1e-8, momentum = .1) =
+  BatchNorm(λ, param(initβ(dims)), param(initγ(dims)), 0., 1., ϵ, momentum, true)
+
+function (BN::BatchNorm)(x)
+  λ, γ, β = BN.λ, BN.γ, BN.β
+
+  if !BN.active
+    μ = BN.μ
+    σ = BN.σ
+  else
+    T = eltype(x)
+
+    ϵ = T(BN.ϵ)
+    m = size(x, 2)  # batch size
+    μ = mean(x, 2)
+    σ = sqrt.(sum((x .- μ).^2, 2) ./ m .+ ϵ)
+
+    # update moving mean/std
+    mtm = T(BN.momentum)
+    BN.μ = (1 - mtm) .* BN.μ .+ mtm .* μ.data
+    BN.σ = (1 - mtm) .* BN.σ .+ mtm .* σ.data .* m ./ (m - 1)
+  end
+
+  λ.(γ .* ((x .- μ) ./ σ) .+ β)
+end
+
+children(BN::BatchNorm) =
+  (BN.λ, BN.β, BN.γ, BN.μ, BN.σ, BN.ϵ, BN.momentum, BN.active)
+
+mapchildren(f, BN::BatchNorm) = # e.g. mapchildren(cu, BN)
+  BatchNorm(BN.λ, f(BN.β), f(BN.γ), BN.μ, BN.σ, BN.ϵ, BN.momentum, BN.active)
+
+_testmode!(BN::BatchNorm, test) = (BN.active = !test)
+
+function Base.show(io::IO, l::BatchNorm)
+  print(io, "BatchNorm($(join(size(l.β), ", "))")
+  (l.λ == identity) || print(io, ", λ = $(l.λ)")
+  print(io, ")")
+end
diff --git a/test/layers/normalisation.jl b/test/layers/normalisation.jl
index 5a302a51..118a5700 100644
--- a/test/layers/normalisation.jl
+++ b/test/layers/normalisation.jl
@@ -26,3 +26,55 @@ using Flux: testmode!
   y = m(x)
   @test count(a->a == 0, y) == 0
 end
+
+@testset "BatchNorm" begin
+  let m = BatchNorm(2), x = param([1 2; 3 4; 5 6]')
+
+    @test m.β.data == [0, 0]  # initβ(2)
+    @test m.γ.data == [1, 1]  # initγ(2)
+    # initial m.σ is 1
+    # initial m.μ is 0
+    @test m.active
+
+    # @test m(x).data ≈ [-1 -1; 0 0; 1 1]'
+    m(x)
+
+    # julia> x
+    # 2×3 Array{Float64,2}:
+    #  1.0  3.0  5.0
+    #  2.0  4.0  6.0
+    #
+    # μ of the batch will be
+    #  (1. + 3. + 5.) / 3 = 3
+    #  (2. + 4. + 6.) / 3 = 4
+    #
+    # ∴ update rule with momentum:
+    #  .1 * 3 + 0 = .3
+    #  .1 * 4 + 0 = .4
+    @test m.μ ≈ reshape([0.3, 0.4], 2, 1)
+
+    # julia> .1 .* std(x, 2, corrected=false) .* (3 / 2) .+ .9 .* [1., 1.]
+    # 2×1 Array{Float64,2}:
+    #  1.14495
+    #  1.14495
+    @test m.σ ≈ .1 .* std(x.data, 2, corrected=false) .* (3 / 2) .+ .9 .* [1., 1.]
+
+    testmode!(m)
+    @test !m.active
+
+    x′ = m(x).data
+    @test x′[1] ≈ (1 - 0.3) / 1.1449489742783179
+  end
+
+  # with activation function
+  let m = BatchNorm(2, λ = σ), x = param([1 2; 3 4; 5 6]')
+    @test m.active
+    m(x)
+
+    testmode!(m)
+    @test !m.active
+
+    x′ = m(x).data
+    @test x′[1] ≈ σ((1 - 0.3) / 1.1449489742783179)
+  end
+end
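
For reference, the moving-statistics arithmetic asserted in the `BatchNorm` testset can be checked by hand. The snippet below is a minimal, self-contained sketch (not part of the diff) that reproduces the expected moving mean, moving std, and test-mode output for the same 2×3 input. It assumes the Julia 0.6-era `mean`/`std` methods with a dimension argument that the tests themselves use, and the names `μ_new`/`σ_new` are illustrative only.

```julia
# Batch of 2 features × 3 samples, as in the tests ([1 2; 3 4; 5 6]').
x = [1.0 3.0 5.0;
     2.0 4.0 6.0]

m = size(x, 2)                       # batch size = 3
μ = mean(x, 2)                       # per-feature batch mean: [3.0; 4.0]
σ = std(x, 2, corrected = false)     # uncorrected per-feature std ≈ 1.63299
# (the layer also adds ϵ = 1e-8 inside the variance, negligible at this scale)

mtm = 0.1                            # default momentum; initial moving μ = 0, σ = 1
μ_new = (1 - mtm) .* 0 .+ mtm .* μ                   # → [0.3; 0.4]
σ_new = (1 - mtm) .* 1 .+ mtm .* σ .* m ./ (m - 1)   # → ≈ [1.14495; 1.14495]

# Test-mode output for the first entry, matching the final `@test`:
(x[1] - μ_new[1]) / σ_new[1]         # ≈ (1 - 0.3) / 1.1449489742783179 ≈ 0.611
```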