diff --git a/src/layers/normalise.jl b/src/layers/normalise.jl index c25fe798..9cc6876a 100644 --- a/src/layers/normalise.jl +++ b/src/layers/normalise.jl @@ -72,7 +72,7 @@ function Base.show(io::IO, l::LayerNorm) end """ - BatchNorm(channels::Integer, σ = identity; + BatchNorm(channels::Integer, σ = identity; initβ = zeros, initγ = ones, ϵ = 1e-8, momentum = .1) @@ -102,11 +102,11 @@ m = Chain( ``` """ mutable struct BatchNorm{F,V,W,N} - λ::F # activation function - β::V # bias - γ::V # scale - μ::W # moving mean - σ::W # moving std + λ::F # activation function + β::V # bias + γ::V # scale + μ::W # moving mean + σ²::W # moving variance ϵ::N momentum::N active::Bool @@ -132,31 +132,31 @@ function (BN::BatchNorm)(x) if !BN.active μ = reshape(BN.μ, affine_shape...) - σ = reshape(BN.σ, affine_shape...) + σ² = reshape(BN.σ², affine_shape...) else T = eltype(data(x)) axes = [1:dims-2; dims] # axes to reduce along (all but channels axis) μ = mean(x, dims = axes) meansub = (x .- μ) - σ = mean(meansub .* meansub, dims = axes) + σ² = mean(meansub .* meansub, dims = axes) # update moving mean/std mtm = convert(T, data(BN.momentum)) BN.μ = (1 - mtm) .* BN.μ .+ mtm .* data(reshape(μ, :)) - BN.σ = ((1 - mtm) .* BN.σ .+ mtm .* data(reshape(σ, :)) .* m ./ (m - 1)) + BN.σ² = ((1 - mtm) .* BN.σ² .+ mtm .* data(reshape(σ², :)) .* m ./ (m - 1)) end let λ = BN.λ, ϵ = eltype(data(σ²))(BN.ϵ) - λ.(reshape(γ, affine_shape...) .* ((x .- μ) ./ sqrt.(σ .+ ϵ)) .+ reshape(β, affine_shape...)) + λ.(reshape(γ, affine_shape...) .* ((x .- μ) ./ sqrt.(σ² .+ ϵ)) .+ reshape(β, affine_shape...)) end end children(BN::BatchNorm) = - (BN.λ, BN.β, BN.γ, BN.μ, BN.σ, BN.ϵ, BN.momentum, BN.active) + (BN.λ, BN.β, BN.γ, BN.μ, BN.σ², BN.ϵ, BN.momentum, BN.active) mapchildren(f, BN::BatchNorm) = # e.g. 
mapchildren(cu, BN) - BatchNorm(BN.λ, f(BN.β), f(BN.γ), f(BN.μ), f(BN.σ), BN.ϵ, BN.momentum, BN.active) + BatchNorm(BN.λ, f(BN.β), f(BN.γ), f(BN.μ), f(BN.σ²), BN.ϵ, BN.momentum, BN.active) _testmode!(BN::BatchNorm, test) = (BN.active = !test)