diff --git a/src/layers/normalise.jl b/src/layers/normalise.jl index 88c06673..c0d4aabf 100644 --- a/src/layers/normalise.jl +++ b/src/layers/normalise.jl @@ -114,6 +114,8 @@ BatchNorm(chs::Integer, λ = identity; zeros(chs), ones(chs), ϵ, momentum, true) function (BN::BatchNorm)(x) + size(x, ndims(x)-1) == length(BN.β) || + error("BatchNorm expected $(length(BN.β)) channels, got $(size(x, ndims(x)-1))") λ, γ, β = BN.λ, BN.γ, BN.β dims = length(size(x)) channels = size(x, dims-1)