batchnorm: add test cases
This commit is contained in:
parent
b3356cc6bb
commit
ce46843459
@ -26,3 +26,43 @@ using Flux: testmode!
|
||||
y = m(x)
|
||||
@test count(a->a == 0, y) == 0
|
||||
end
|
||||
|
||||
@testset "BatchNorm" begin
|
||||
let m = BatchNorm(2), x = param([1 2; 3 4; 5 6]')
|
||||
|
||||
@test m.β.data == [0, 0] # initβ(2)
|
||||
@test m.γ.data == [1, 1] # initγ(2)
|
||||
# initial m.σ is 1
|
||||
# initial m.μ is 0
|
||||
@test m.active
|
||||
|
||||
# @test m(x).data ≈ [-1 -1; 0 0; 1 1]'
|
||||
m(x)
|
||||
|
||||
# julia> x
|
||||
# 2×3 Array{Float64,2}:
|
||||
# 1.0 3.0 5.0
|
||||
# 2.0 4.0 6.0
|
||||
#
|
||||
# μ of batch will be
|
||||
# (1. + 3. + 5.) / 3 = 3
|
||||
# (2. + 4. + 6.) / 3 = 4
|
||||
#
|
||||
# ∴ update rule with momentum:
|
||||
# .1 * 3 + 0 = .3
|
||||
# .1 * 4 + 0 = .4
|
||||
@test m.μ ≈ reshape([0.3, 0.4], 2, 1)
|
||||
|
||||
# julia> .1 .* std(x, 2, corrected=false) .* (3 / 2).+ .9 .* [1., 1.]
|
||||
# 2×1 Array{Float64,2}:
|
||||
# 1.14495
|
||||
# 1.14495
|
||||
@test m.σ ≈ .1 .* std(x.data, 2, corrected=false) .* (3 / 2).+ .9 .* [1., 1.]
|
||||
|
||||
testmode!(m)
|
||||
@test !m.active
|
||||
|
||||
x′ = m(x).data
|
||||
@test x′[1] ≈ (1 - 0.3) / 1.1449489742783179
|
||||
end
|
||||
end
|
||||
|
Loading…
Reference in New Issue
Block a user