diff --git a/src/layers/normalisation.jl b/src/layers/normalisation.jl index a105e1ba..f7a425b2 100644 --- a/src/layers/normalisation.jl +++ b/src/layers/normalisation.jl @@ -92,12 +92,12 @@ function (BN::BatchNorm)(x) ϵ = T(BN.ϵ) m = size(x, 2) # batch size μ = sum(x, 2) ./ m - σ = sqrt.(sum((x .- μ).^2, 2) ./ (m - 1) .+ ϵ) + σ = sqrt.(sum((x .- μ).^2, 2) ./ m .+ ϵ) # update moving mean/std mtm = T(BN.momentum) - BN.μ = mtm .* μ.data .+ (1 - mtm) .* BN.μ - BN.σ = mtm .* σ.data .+ (1 - mtm) .* BN.σ + BN.μ = (1 - mtm) .* BN.μ .+ mtm .* μ.data + BN.σ = (1 - mtm) .* BN.σ .+ mtm .* σ.data .* m ./ (m - 1) end BN.λ.(BN.γ .* ((x .- μ) ./ σ) .+ BN.β)