# Unit tests for Flux's BatchNorm layer: parameter initialisation, the
# running-statistics update performed by a training-mode forward pass, and
# inference-mode normalisation after `testmode!`.
# NOTE(review): written against the pre-1.0 Flux / Julia 0.6 API — `param`,
# `.data` on tracked arrays, the mutable `active` training flag, and the
# positional-dims `std(x, 2, corrected=false)` form. Confirm against the
# Flux version under test.
@testset "BatchNorm" begin
    # x is 2 features × 3 samples: [1 3 5; 2 4 6]
    let m = BatchNorm(2), x = param([1 2; 3 4; 5 6]')

        # Freshly constructed layer: shift and scale at their initial values.
        @test m.β.data == [0, 0] # initβ(2)
        @test m.γ.data == [1, 1] # initγ(2)
        # initial m.σ is 1
        # initial m.μ is 0
        # Layer starts in training mode.
        @test m.active

        # @test m(x).data ≈ [-1 -1; 0 0; 1 1]'
        # Training-mode forward pass; its side effect is updating the
        # running statistics checked below.
        m(x)

        # julia> x
        # 2×3 Array{Float64,2}:
        #  1.0  3.0  5.0
        #  2.0  4.0  6.0
        #
        # μ of batch will be
        #  (1. + 3. + 5.) / 3 = 3
        #  (2. + 4. + 6.) / 3 = 4
        #
        # ∴ update rule with momentum (default momentum 0.1, initial μ = 0):
        #  .1 * 3 + 0 = .3
        #  .1 * 4 + 0 = .4
        @test m.μ ≈ reshape([0.3, 0.4], 2, 1)

        # julia> .1 .* std(x, 2, corrected=false) .* (3 / 2).+ .9 .* [1., 1.]
        # 2×1 Array{Float64,2}:
        #  1.14495
        #  1.14495
        # NOTE(review): the (3 / 2) factor presumably mirrors the n/(n-1)
        # scaling BatchNorm applies internally when folding the batch
        # statistic into the running σ — confirm against the Flux source.
        @test m.σ ≈ .1 .* std(x.data, 2, corrected=false) .* (3 / 2).+ .9 .* [1., 1.]

        # testmode! switches the layer to inference, so the forward pass
        # below normalises with the running μ/σ instead of batch stats.
        testmode!(m)
        @test !m.active

        # First entry normalised with the running stats computed above:
        # (x[1,1] - μ[1]) / σ[1] = (1 - 0.3) / 1.1449489742783179.
        x′ = m(x).data
        @test x′[1] ≈ (1 - 0.3) / 1.1449489742783179
    end
end