From 6044421c5cb2660c7baaff1d9500fa0f871e045b Mon Sep 17 00:00:00 2001 From: Sklan Date: Wed, 20 Feb 2019 13:47:31 +0530 Subject: [PATCH 1/2] Update normalise.jl --- src/layers/normalise.jl | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/src/layers/normalise.jl b/src/layers/normalise.jl index 1783d3ef..b9f7a86c 100644 --- a/src/layers/normalise.jl +++ b/src/layers/normalise.jl @@ -113,34 +113,32 @@ BatchNorm(chs::Integer, λ = identity; function (BN::BatchNorm)(x) size(x, ndims(x)-1) == length(BN.β) || error("BatchNorm expected $(length(BN.β)) channels, got $(size(x, ndims(x)-1))") - γ, β = BN.γ, BN.β dims = length(size(x)) channels = size(x, dims-1) affine_shape = ones(Int, dims) affine_shape[end-1] = channels m = prod(size(x)[1:end-2]) * size(x)[end] - + γ = reshape(BN.γ, affine_shape...) + β = reshape(BN.β, affine_shape...) if !BN.active μ = reshape(BN.μ, affine_shape...) σ² = reshape(BN.σ², affine_shape...) + ϵ = BN.ϵ else T = eltype(x) - - ϵ = data(convert(T, BN.ϵ)) axes = [1:dims-2; dims] # axes to reduce along (all but channels axis) μ = mean(x, dims = axes) σ² = sum((x .- μ) .^ 2, dims = axes) ./ m - + ϵ = data(convert(T, BN.ϵ)) # update moving mean/std mtm = data(convert(T, BN.momentum)) BN.μ = (1 - mtm) .* BN.μ .+ mtm .* reshape(data(μ), :) - BN.σ² = ((1 - mtm) .* BN.σ² .+ mtm .* reshape(data(σ²), :) .* m ./ (m - 1)) + BN.σ² = (1 - mtm) .* BN.σ² .+ mtm .* reshape(data(σ²), :) .* m ./ (m - 1) end let λ = BN.λ - temp = reshape(γ, affine_shape...) .* ((x .- μ) ./ sqrt.(σ² .+ BN.ϵ)) - # This is intentionally not fused because of an extreme slowdown doing so - λ.(temp .+ reshape(β, affine_shape...)) + x̂ = (x .- μ) ./ sqrt.(σ² .+ ϵ) + λ.(γ .* x̂ .+ β) end end From 7463f0959176912d7a41bc57e0f20e7b14bf4902 Mon Sep 17 00:00:00 2001 From: Sklan Date: Thu, 21 Feb 2019 23:56:19 +0530 Subject: [PATCH 2/2] Update normalise.jl --- src/layers/normalise.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/layers/normalise.jl b/src/layers/normalise.jl index b9f7a86c..e48d26fb 100644 --- a/src/layers/normalise.jl +++ b/src/layers/normalise.jl @@ -133,7 +133,7 @@ function (BN::BatchNorm)(x) # update moving mean/std mtm = data(convert(T, BN.momentum)) BN.μ = (1 - mtm) .* BN.μ .+ mtm .* reshape(data(μ), :) - BN.σ² = (1 - mtm) .* BN.σ² .+ mtm .* reshape(data(σ²), :) .* m ./ (m - 1) + BN.σ² = (1 - mtm) .* BN.σ² .+ (mtm * m / (m - 1)) .* reshape(data(σ²), :) end let λ = BN.λ