diff --git a/src/layers/normalise.jl b/src/layers/normalise.jl index c0d4aabf..74905a36 100644 --- a/src/layers/normalise.jl +++ b/src/layers/normalise.jl @@ -116,7 +116,7 @@ BatchNorm(chs::Integer, λ = identity; function (BN::BatchNorm)(x) size(x, ndims(x)-1) == length(BN.β) || error("BatchNorm expected $(length(BN.β)) channels, got $(size(x, ndims(x)-1))") - λ, γ, β = BN.λ, BN.γ, BN.β + γ, β = BN.γ, BN.β dims = length(size(x)) channels = size(x, dims-1) affine_shape = ones(Int, dims) @@ -140,7 +140,9 @@ function (BN::BatchNorm)(x) BN.σ = (1 - mtm) .* BN.σ .+ mtm .* squeeze(data(σ), (axes...)) .* m ./ (m - 1) end - λ.(reshape(γ, affine_shape...) .* ((x .- μ) ./ σ) .+ reshape(β, affine_shape...)) + let λ = BN.λ + λ.(reshape(γ, affine_shape...) .* ((x .- μ) ./ σ) .+ reshape(β, affine_shape...)) + end end children(BN::BatchNorm) =