diff --git a/src/Flux.jl b/src/Flux.jl
index ff178450..2a665336 100644
--- a/src/Flux.jl
+++ b/src/Flux.jl
@@ -6,10 +6,8 @@ using Base: tail
 using MacroTools, Juno, Requires, Reexport, Statistics, Random
 using MacroTools: @forward
 
-export Chain, Dense, Maxout,
-  RNN, LSTM, GRU,
-  Conv, ConvTranspose, MaxPool, MeanPool, DepthwiseConv,
-  Dropout, AlphaDropout, LayerNorm, BatchNorm, InstanceNorm,
+export Chain, Dense, Maxout, RNN, LSTM, GRU, Conv, ConvTranspose, MaxPool,
+  MeanPool, DepthwiseConv, Dropout, AlphaDropout, LayerNorm, BatchNorm, InstanceNorm,
   GroupNorm, params, mapleaves, cpu, gpu, f32, f64
 
 @reexport using NNlib
diff --git a/src/layers/normalise.jl b/src/layers/normalise.jl
index 5fd93e9d..0e66e1ab 100644
--- a/src/layers/normalise.jl
+++ b/src/layers/normalise.jl
@@ -286,3 +286,107 @@ function Base.show(io::IO, l::InstanceNorm)
   (l.λ == identity) || print(io, ", λ = $(l.λ)")
   print(io, ")")
 end
+
+"""
+Group Normalization.
+Known to improve accuracy on classification and segmentation tasks.
+
+GroupNorm(chs::Integer, G::Integer, λ = identity;
+          initβ = (i) -> zeros(Float32, i), initγ = (i) -> ones(Float32, i),
+          ϵ = 1f-5, momentum = 0.1f0)
+
+`chs` is the number of channels, i.e. the channel dimension of your input.
+For an array of N dimensions, the (N-1)th index is the channel dimension.
+
+`G` is the number of groups along which the statistics are computed.
+The number of channels must be divisible by the number of groups.
+
+Example:
+```
+m = Chain(Conv((3,3), 1=>32, leakyrelu; pad = 1),
+          GroupNorm(32,16)) # 32 channels, 16 groups (G = 16), thus 2 channels per group
+```
+
+Link: https://arxiv.org/pdf/1803.08494.pdf
+"""
+mutable struct GroupNorm{F,V,W,N,T}
+  G::T # number of groups
+  λ::F # activation function
+  β::V # bias
+  γ::V # scale
+  μ::W # moving mean
+  σ²::W # moving variance
+  ϵ::N
+  momentum::N
+  active::Bool
+end
+
+GroupNorm(chs::Integer, G::Integer, λ = identity;
+          initβ = (i) -> zeros(Float32, i), initγ = (i) -> ones(Float32, i), ϵ = 1f-5, momentum = 0.1f0) =
+  GroupNorm(G, λ, param(initβ(chs)), param(initγ(chs)),
+            zeros(G,1), ones(G,1), ϵ, momentum, true)
+
+function (gn::GroupNorm)(x)
+  size(x, ndims(x)-1) == length(gn.β) || error("Group Norm expected $(length(gn.β)) channels, but got $(size(x, ndims(x)-1)) channels")
+  ndims(x) > 2 || error("Need an input with at least 3 dimensions for Group Norm to work")
+  size(x, ndims(x)-1) % gn.G == 0 || error("The number of groups ($(gn.G)) must divide the number of channels ($(size(x, ndims(x)-1)))")
+
+  dims = length(size(x))
+  groups = gn.G
+  channels = size(x, dims-1)
+  batches = size(x, dims)
+  channels_per_group = div(channels, groups)
+
+  # reshape the affine parameters so they broadcast along the channel dimension
+  affine_shape = ones(Int, dims)
+  affine_shape[end-1] = channels
+
+  m = prod(size(x)[1:end-2]) * channels_per_group
+  γ = reshape(gn.γ, affine_shape...)
+  β = reshape(gn.β, affine_shape...)
+
+  og_shape = size(x)
+  # reshape the input to (W,H,...,C/G,G,N) so statistics are taken per group
+  y = reshape(x, (size(x)[1:end-2]..., channels_per_group, groups, batches))
+
+  if !gn.active
+    # test mode: broadcast the moving statistics over the grouped input
+    stats_shape = ones(Int, dims + 1)
+    stats_shape[end-1] = groups
+    μ = reshape(gn.μ, stats_shape...)
+    σ² = reshape(gn.σ², stats_shape...)
+    ϵ = gn.ϵ
+  else
+    T = eltype(x)
+    axes = [(1:ndims(y)-2)...] # axes to reduce along (all but the group and batch axes)
+    μ = mean(y, dims = axes)
+    σ² = sum((y .- μ) .^ 2, dims = axes) ./ m
+    ϵ = data(convert(T, gn.ϵ))
+    # update the moving mean/variance, averaging the per-sample statistics over the batch
+    mtm = data(convert(T, gn.momentum))
+    gn.μ = (1 - mtm) .* gn.μ .+ mtm .* mean(reshape(data(μ), (groups, batches)), dims = 2)
+    gn.σ² = (1 - mtm) .* gn.σ² .+ (mtm * m / (m - 1)) .* mean(reshape(data(σ²), (groups, batches)), dims = 2)
+  end
+
+  let λ = gn.λ
+    x̂ = (y .- μ) ./ sqrt.(σ² .+ ϵ)
+
+    # reshape x̂ back to the original input shape
+    x̂ = reshape(x̂, og_shape)
+    λ.(γ .* x̂ .+ β)
+  end
+end
+
+children(gn::GroupNorm) =
+  (gn.λ, gn.β, gn.γ, gn.μ, gn.σ², gn.ϵ, gn.momentum, gn.active)
+
+mapchildren(f, gn::GroupNorm) = # e.g. mapchildren(cu, GN)
+  GroupNorm(gn.G, gn.λ, f(gn.β), f(gn.γ), f(gn.μ), f(gn.σ²), gn.ϵ, gn.momentum, gn.active)
+
+_testmode!(gn::GroupNorm, test) = (gn.active = !test)
+
+function Base.show(io::IO, l::GroupNorm)
+  print(io, "GroupNorm($(join(size(l.β), ", "))")
+  (l.λ == identity) || print(io, ", λ = $(l.λ)")
+  print(io, ")")
+end
\ No newline at end of file
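
For reviewers, a minimal usage sketch (not part of the diff) of how the new layer composes with the existing API; the model shape and input sizes below are arbitrary, and `Flux.testmode!` is the existing switch this patch hooks into via `_testmode!`:

```julia
using Flux

# 32 channels split into 16 groups (G = 16): statistics over 2 channels per group
m = Chain(Conv((3, 3), 1=>32, leakyrelu; pad = 1),
          GroupNorm(32, 16))

x = rand(Float32, 28, 28, 1, 4)   # W × H × C × N
y = m(x)                          # active mode: statistics computed from this batch

Flux.testmode!(m)                 # switch to the accumulated moving statistics
ŷ = m(x)
```

Because the statistics are computed per sample and per group, the output for a given sample is independent of the batch it arrives in while the layer is active, which is the motivation given in the linked paper.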