diff --git a/src/layers/normalise.jl b/src/layers/normalise.jl
index 9528cec4..d02aee35 100644
--- a/src/layers/normalise.jl
+++ b/src/layers/normalise.jl
@@ -264,11 +264,11 @@ function Base.show(io::IO, l::InstanceNorm)
 end
 
 """
-Group Normalization. 
+Group Normalization.
 This layer can outperform Batch-Normalization and Instance-Normalization.
 
     GroupNorm(chs::Integer, G::Integer, λ = identity;
-              initβ = (i) -> zeros(Float32, i), initγ = (i) -> ones(Float32, i), 
+              initβ = (i) -> zeros(Float32, i), initγ = (i) -> ones(Float32, i),
               ϵ = 1f-5, momentum = 0.1f0)
 
 ``chs`` is the number of channels, the channel dimension of your input.
@@ -280,7 +280,7 @@ The number of channels must be an integer multiple of the number of groups.
 Example:
 ```
 m = Chain(Conv((3,3), 1=>32, leakyrelu;pad = 1),
-          GroupNorm(32,16)) # 32 channels, 16 groups (G = 16), thus 2 channels per group used 
+          GroupNorm(32,16)) # 32 channels, 16 groups (G = 16), thus 2 channels per group used
 ```
 
 Link : https://arxiv.org/pdf/1803.08494.pdf
@@ -295,7 +295,6 @@ mutable struct GroupNorm{F,V,W,N,T}
   σ²::W # moving std
   ϵ::N
   momentum::N
-  active::Bool
 end
 
 GroupNorm(chs::Integer, G::Integer, λ = identity;
@@ -324,9 +323,9 @@ function(gn::GroupNorm)(x)
   m = prod(size(x)[1:end-2]) * channels_per_group
   γ = reshape(gn.γ, affine_shape...)
   β = reshape(gn.β, affine_shape...)
-  
+
   y = reshape(x,((size(x))[1:end-2]...,channels_per_group,groups,batches))
-  if !gn.active
+  if !istraining()
     og_shape = size(x)
     μ = reshape(gn.μ, μ_affine_shape...) # Shape : (1,1,...C/G,G,1)
     σ² = reshape(gn.σ², μ_affine_shape...) # Shape : (1,1,...C/G,G,1)
@@ -337,7 +336,7 @@
     axes = [(1:ndims(y)-2)...] # axes to reduce along (all but channels axis)
     μ = mean(y, dims = axes)
     σ² = mean((y .- μ) .^ 2, dims = axes)
-    
+
     ϵ = data(convert(T, gn.ϵ))
     # update moving mean/std
     mtm = data(convert(T, gn.momentum))
@@ -349,7 +348,7 @@
   let λ = gn.λ
     x̂ = (y .- μ) ./ sqrt.(σ² .+ ϵ)
-    # Reshape x̂ 
+    # Reshape x̂
     x̂ = reshape(x̂,og_shape)
     λ.(γ .* x̂ .+ β)
   end
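
Note on the `active` removal: the manual `active::Bool` flag is replaced by `istraining()`. In Flux of this era that function is defined roughly as below; this is a minimal sketch of the mechanism, not a verbatim copy of the source:

```
using Zygote: @adjoint

# A plain call returns false, so gn(x) outside of a gradient
# computation takes the test-mode branch and uses the moving stats.
istraining() = false

# During gradient computation Zygote substitutes this adjoint, so the
# same call returns true and the layer normalises with batch statistics.
@adjoint istraining() = true, _ -> nothing
```

The design choice: train/test mode is inferred from whether the layer runs inside a gradient call, so there is no per-layer flag to keep in sync.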
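A usage sketch of the new behaviour, assuming a Flux build with this patch applied and the Zygote-style implicit-parameter API (`gradient`/`params`):

```
using Flux

gn = GroupNorm(32, 16)           # 32 channels, 16 groups
x  = rand(Float32, 4, 4, 32, 8)  # W × H × C × N

y = gn(x)                        # istraining() == false: moving stats used

# Inside the gradient call istraining() == true: batch statistics are
# computed and the moving mean/std are updated as a side effect.
gs = gradient(() -> sum(gn(x)), params(gn))
```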
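For reviewers new to this layer: the reshape the touched lines operate on views a `(W, H, C, N)` input as `(W, H, C÷G, G, N)` and reduces over every axis except the group and batch axes. A standalone walk-through in plain Julia (illustrative names and values, no Flux required):

```
using Statistics: mean

x = rand(Float32, 4, 4, 32, 8)       # W × H × C × N, as in the docstring example
G = 16
channels_per_group = size(x, 3) ÷ G  # 32 ÷ 16 = 2

# Split the channel axis into (channels per group, groups):
y = reshape(x, 4, 4, channels_per_group, G, 8)

# Reduce over all axes but the last two, mirroring
# `axes = [(1:ndims(y)-2)...]` in the patch:
reduce_dims = Tuple(1:ndims(y)-2)    # (1, 2, 3)
μ  = mean(y, dims = reduce_dims)     # size (1, 1, 1, 16, 8): one mean per group per sample
σ² = mean((y .- μ) .^ 2, dims = reduce_dims)

x̂ = (y .- μ) ./ sqrt.(σ² .+ 1f-5)    # normalise, then restore the original shape
x̂ = reshape(x̂, size(x))
```

The moving statistics stored on the layer drop the batch axis, which is why the test-mode branch reshapes them to `(1,1,...C/G,G,1)` as the patch's comments note.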