diff --git a/src/layers/normalise.jl b/src/layers/normalise.jl index fc781f70..0647e6b4 100644 --- a/src/layers/normalise.jl +++ b/src/layers/normalise.jl @@ -157,6 +157,8 @@ mutable struct BatchNorm{F,V,W,N} active::Union{Bool, Nothing} end +BatchNorm(λ, β, γ, μ, σ², ϵ, momentum) = BatchNorm(λ, β, γ, μ, σ², ϵ, momentum, nothing) + BatchNorm(chs::Integer, λ = identity; initβ = (i) -> zeros(Float32, i), initγ = (i) -> ones(Float32, i), ϵ = 1f-5, momentum = 0.1f0) = BatchNorm(λ, initβ(chs), initγ(chs),