diff --git a/src/layers/normalise.jl b/src/layers/normalise.jl
index 1783d3ef..e48d26fb 100644
--- a/src/layers/normalise.jl
+++ b/src/layers/normalise.jl
@@ -113,34 +113,32 @@ BatchNorm(chs::Integer, λ = identity;
 
 function (BN::BatchNorm)(x)
   size(x, ndims(x)-1) == length(BN.β) ||
     error("BatchNorm expected $(length(BN.β)) channels, got $(size(x, ndims(x)-1))")
-  γ, β = BN.γ, BN.β
   dims = length(size(x))
   channels = size(x, dims-1)
   affine_shape = ones(Int, dims)
   affine_shape[end-1] = channels
   m = prod(size(x)[1:end-2]) * size(x)[end]
-
+  γ = reshape(BN.γ, affine_shape...)
+  β = reshape(BN.β, affine_shape...)
   if !BN.active
     μ = reshape(BN.μ, affine_shape...)
     σ² = reshape(BN.σ², affine_shape...)
+    ϵ = BN.ϵ
   else
     T = eltype(x)
-
-    ϵ = data(convert(T, BN.ϵ))
     axes = [1:dims-2; dims] # axes to reduce along (all but channels axis)
     μ = mean(x, dims = axes)
     σ² = sum((x .- μ) .^ 2, dims = axes) ./ m
-
+    ϵ = data(convert(T, BN.ϵ))
     # update moving mean/std
     mtm = data(convert(T, BN.momentum))
     BN.μ = (1 - mtm) .* BN.μ .+ mtm .* reshape(data(μ), :)
-    BN.σ² = ((1 - mtm) .* BN.σ² .+ mtm .* reshape(data(σ²), :) .* m ./ (m - 1))
+    BN.σ² = (1 - mtm) .* BN.σ² .+ (mtm * m / (m - 1)) .* reshape(data(σ²), :)
   end
 
   let λ = BN.λ
-    temp = reshape(γ, affine_shape...) .* ((x .- μ) ./ sqrt.(σ² .+ BN.ϵ))
-    # This is intentionally not fused because of an extreme slowdown doing so
-    λ.(temp .+ reshape(β, affine_shape...))
+    x̂ = (x .- μ) ./ sqrt.(σ² .+ ϵ)
+    λ.(γ .* x̂ .+ β)
   end
 end
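
For reference, below is a minimal standalone sketch of the normalization math the rewritten hunk performs, outside of Flux and Tracker (so no `data` calls, no `BN.active` branch, and no moving-statistics updates). The `batchnorm_forward` name, its signature, and the keyword defaults are illustrative assumptions, not part of the patch:

using Statistics  # for `mean`

# Hypothetical helper (not part of the patch): the affine batch-norm
# transform from the rewritten hunk, for a plain array `x` whose
# second-to-last dimension is the channel dimension.
function batchnorm_forward(x::AbstractArray, γ::AbstractVector, β::AbstractVector;
                           ϵ = 1f-5, λ = identity)
  dims = ndims(x)
  channels = size(x, dims - 1)
  length(γ) == length(β) == channels ||
    error("expected $channels channels, got γ: $(length(γ)), β: $(length(β))")

  # Reshape the per-channel parameters once so they broadcast against x,
  # e.g. to (1, 1, C, 1) for a WHCN image batch — the same reshape the
  # patch hoists to the top of the function.
  affine_shape = ones(Int, dims)
  affine_shape[end-1] = channels
  γr = reshape(γ, affine_shape...)
  βr = reshape(β, affine_shape...)

  # Reduce over every axis except the channel axis, mirroring
  # `axes = [1:dims-2; dims]` in the patch. The mean of squared
  # deviations is the same biased variance as `sum((x .- μ).^2) ./ m`;
  # the patch applies the m/(m - 1) Bessel correction only when updating
  # the running variance, which this sketch omits.
  reduce_dims = ((1:dims-2)..., dims)
  μ = mean(x, dims = reduce_dims)
  σ² = mean(abs2.(x .- μ), dims = reduce_dims)

  # Normalise, then scale, shift, and activate in fused broadcasts, as
  # the rewritten `let` block does.
  x̂ = (x .- μ) ./ sqrt.(σ² .+ ϵ)
  λ.(γr .* x̂ .+ βr)
end

# Usage: 5 samples, 3 channels, 4 features each.
x = randn(Float32, 4, 3, 5)
y = batchnorm_forward(x, ones(Float32, 3), zeros(Float32, 3))
@assert size(y) == size(x)

Note the design change visible in the final hunk: the old code deliberately kept the scale and shift as separate broadcasts (per its removed comment about a fusion slowdown), while the new code computes `x̂` once and fuses `λ.(γ .* x̂ .+ β)` into a single broadcast, also using the branch-local `ϵ` instead of re-reading `BN.ϵ`.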