diff --git a/src/layers/normalise.jl b/src/layers/normalise.jl index 59b39ca7..2876cdd7 100644 --- a/src/layers/normalise.jl +++ b/src/layers/normalise.jl @@ -136,7 +136,7 @@ function (BN::BatchNorm)(x) dims = length(size(x)) channels = size(x, dims-1) affine_shape = ntuple(i->i == ndims(x) - 1 ? size(x, i) : 1, ndims(x)) - m = prod(size(x)[1:end-2]) * size(x)[end] + m = trunc(Int, prod(size(x))/channels) γ = reshape(BN.γ, affine_shape...) β = reshape(BN.β, affine_shape...) if !istraining()