From a7143553df6375b97a5ee4d3fd6ff0bc56c1ffb9 Mon Sep 17 00:00:00 2001 From: Elliot Saba Date: Sun, 20 Jan 2019 23:57:19 +0000 Subject: [PATCH] =?UTF-8?q?Change=20name=20to=20=CF=83=C2=B2=20for=20bette?= =?UTF-8?q?r=20consistency?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/layers/normalise.jl | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/src/layers/normalise.jl b/src/layers/normalise.jl index c25fe798..9cc6876a 100644 --- a/src/layers/normalise.jl +++ b/src/layers/normalise.jl @@ -72,7 +72,7 @@ function Base.show(io::IO, l::LayerNorm) end """ - BatchNorm(channels::Integer, σ = identity; + BatchNorm(channels::Integer, σ = identity; initβ = zeros, initγ = ones, ϵ = 1e-8, momentum = .1) @@ -102,11 +102,11 @@ m = Chain( ``` """ mutable struct BatchNorm{F,V,W,N} - λ::F # activation function - β::V # bias - γ::V # scale - μ::W # moving mean - σ::W # moving std + λ::F # activation function + β::V # bias + γ::V # scale + μ::W # moving mean + σ²::W # moving variance ϵ::N momentum::N active::Bool @@ -132,31 +132,31 @@ function (BN::BatchNorm)(x) if !BN.active μ = reshape(BN.μ, affine_shape...) - σ = reshape(BN.σ, affine_shape...) + σ² = reshape(BN.σ², affine_shape...) else T = eltype(data(x)) axes = [1:dims-2; dims] # axes to reduce along (all but channels axis) μ = mean(x, dims = axes) meansub = (x .- μ) - σ = mean(meansub .* meansub, dims = axes) + σ² = mean(meansub .* meansub, dims = axes) # update moving mean/std mtm = convert(T, data(BN.momentum)) BN.μ = (1 - mtm) .* BN.μ .+ mtm .* data(reshape(μ, :)) - BN.σ = ((1 - mtm) .* BN.σ .+ mtm .* data(reshape(σ, :)) .* m ./ (m - 1)) + BN.σ² = ((1 - mtm) .* BN.σ² .+ mtm .* data(reshape(σ², :)) .* m ./ (m - 1)) end let λ = BN.λ, ϵ = eltype(data(σ²))(BN.ϵ) - λ.(reshape(γ, affine_shape...) .* ((x .- μ) ./ sqrt.(σ .+ ϵ)) .+ reshape(β, affine_shape...)) + λ.(reshape(γ, affine_shape...) 
.* ((x .- μ) ./ sqrt.(σ² .+ ϵ)) .+ reshape(β, affine_shape...)) end end children(BN::BatchNorm) = - (BN.λ, BN.β, BN.γ, BN.μ, BN.σ, BN.ϵ, BN.momentum, BN.active) + (BN.λ, BN.β, BN.γ, BN.μ, BN.σ², BN.ϵ, BN.momentum, BN.active) mapchildren(f, BN::BatchNorm) = # e.g. mapchildren(cu, BN) - BatchNorm(BN.λ, f(BN.β), f(BN.γ), f(BN.μ), f(BN.σ), BN.ϵ, BN.momentum, BN.active) + BatchNorm(BN.λ, f(BN.β), f(BN.γ), f(BN.μ), f(BN.σ²), BN.ϵ, BN.momentum, BN.active) _testmode!(BN::BatchNorm, test) = (BN.active = !test)