From a140c31f72616bf501b69c909362c2f643d2fd41 Mon Sep 17 00:00:00 2001
From: Mike Innes
Date: Fri, 12 Jul 2019 16:09:42 +0100
Subject: [PATCH] fix batchnorm

---
 src/layers/normalise.jl | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/src/layers/normalise.jl b/src/layers/normalise.jl
index b4d3a035..59b39ca7 100644
--- a/src/layers/normalise.jl
+++ b/src/layers/normalise.jl
@@ -135,8 +135,7 @@ function (BN::BatchNorm)(x)
     error("BatchNorm expected $(length(BN.β)) channels, got $(size(x, ndims(x)-1))")
   dims = length(size(x))
   channels = size(x, dims-1)
-  affine_shape = ones(Int, dims)
-  affine_shape[end-1] = channels
+  affine_shape = ntuple(i->i == ndims(x) - 1 ? size(x, i) : 1, ndims(x))
   m = prod(size(x)[1:end-2]) * size(x)[end]
   γ = reshape(BN.γ, affine_shape...)
   β = reshape(BN.β, affine_shape...)
@@ -151,9 +150,10 @@ function (BN::BatchNorm)(x)
     σ² = sum((x .- μ) .^ 2, dims = axes) ./ m
     ϵ = convert(T, BN.ϵ)
     # update moving mean/std
-    mtm = convert(T, BN.momentum)
-    BN.μ = (1 - mtm) .* BN.μ .+ mtm .* reshape(μ, :)
-    BN.σ² = (1 - mtm) .* BN.σ² .+ (mtm * m / (m - 1)) .* reshape(σ², :)
+    mtm = BN.momentum
+    S = eltype(BN.μ)
+    BN.μ = (1 - mtm) .* BN.μ .+ mtm .* S.(reshape(μ, :))
+    BN.σ² = (1 - mtm) .* BN.σ² .+ (mtm * m / (m - 1)) .* S.(reshape(σ², :))
   end
 
   let λ = BN.λ