diff --git a/src/layers/normalise.jl b/src/layers/normalise.jl index b9f7a86c..e48d26fb 100644 --- a/src/layers/normalise.jl +++ b/src/layers/normalise.jl @@ -133,7 +133,7 @@ function (BN::BatchNorm)(x) # update moving mean/std mtm = data(convert(T, BN.momentum)) BN.μ = (1 - mtm) .* BN.μ .+ mtm .* reshape(data(μ), :) - BN.σ² = (1 - mtm) .* BN.σ² .+ mtm .* reshape(data(σ²), :) .* m ./ (m - 1) + BN.σ² = (1 - mtm) .* BN.σ² .+ (mtm * m / (m - 1)) .* reshape(data(σ²), :) end let λ = BN.λ