diff --git a/src/layers/normalise.jl b/src/layers/normalise.jl index 3755f3fc..7d1d4d0a 100644 --- a/src/layers/normalise.jl +++ b/src/layers/normalise.jl @@ -366,12 +366,10 @@ function(gn::GroupNorm)(x) end children(gn::GroupNorm) = - (gn.λ, gn.β, gn.γ, gn.μ, gn.σ², gn.ϵ, gn.momentum, gn.active) + (gn.λ, gn.β, gn.γ, gn.μ, gn.σ², gn.ϵ, gn.momentum) mapchildren(f, gn::GroupNorm) = # e.g. mapchildren(cu, BN) - GroupNorm(gn.G,gn.λ, f(gn.β), f(gn.γ), f(gn.μ), f(gn.σ²), gn.ϵ, gn.momentum, gn.active) - -_testmode!(gn::GroupNorm, test) = (gn.active = !test) + GroupNorm(gn.G,gn.λ, f(gn.β), f(gn.γ), f(gn.μ), f(gn.σ²), gn.ϵ, gn.momentum) function Base.show(io::IO, l::GroupNorm) print(io, "GroupNorm($(join(size(l.β), ", "))")