Made Requested Changes
This commit is contained in:
parent
35431e3da9
commit
595f1cf6eb
@ -6,10 +6,8 @@ using Base: tail
|
||||
using MacroTools, Juno, Requires, Reexport, Statistics, Random
|
||||
using MacroTools: @forward
|
||||
|
||||
export Chain, Dense, Maxout,
|
||||
RNN, LSTM, GRU,
|
||||
Conv, ConvTranspose, MaxPool, MeanPool, DepthwiseConv,
|
||||
Dropout, AlphaDropout, LayerNorm, BatchNorm, InstanceNorm,
|
||||
export Chain, Dense, RNN, LSTM, GRU, Conv, ConvTranspose, MaxPool, MeanPool,
|
||||
DepthwiseConv, Dropout, AlphaDropout, LayerNorm, BatchNorm, InstanceNorm, GroupNorm,
|
||||
params, mapleaves, cpu, gpu, f32, f64
|
||||
|
||||
@reexport using NNlib
|
||||
|
@ -286,3 +286,103 @@ function Base.show(io::IO, l::InstanceNorm)
|
||||
(l.λ == identity) || print(io, ", λ = $(l.λ)")
|
||||
print(io, ")")
|
||||
end
|
||||
|
||||
"""
|
||||
Group Normalization.
|
||||
Known to improve the overall accuracy in case of classification and segmentation tasks.
|
||||
|
||||
GroupNorm(chs::Integer, G::Integer, λ = identity;
|
||||
initβ = (i) -> zeros(Float32, i), initγ = (i) -> ones(Float32, i),
|
||||
ϵ = 1f-5, momentum = 0.1f0)
|
||||
|
||||
`chs` is the number of channels, i.e. the size of the channel dimension of your input.
|
||||
For an array of N dimensions, the (N-1)th index is the channel dimension.
|
||||
|
||||
G is the number of groups along which the statistics would be computed.
|
||||
The number of groups must divide the number of channels for this to work.
|
||||
|
||||
Example:
|
||||
```
|
||||
m = Chain(Conv((3,3), 1=>32, leakyrelu;pad = 1),
|
||||
GroupNorm(32,16)) # 32 channels, 16 groups (G = 16), thus 2 channels per group used
|
||||
```
|
||||
|
||||
Link : https://arxiv.org/pdf/1803.08494.pdf
|
||||
"""
|
||||
|
||||
mutable struct GroupNorm{F,V,W,N,T}
    G::T          # number of groups the channels are split into
    λ::F          # activation function applied element-wise to the output
    β::V          # per-channel bias (one entry per channel)
    γ::V          # per-channel scale (one entry per channel)
    μ::W          # moving mean, updated while `active` is true
    σ²::W         # moving variance, updated while `active` is true
    ϵ::N          # constant added to the variance before taking the square root
    momentum::N   # weight given to the current batch in the moving-average update
    active::Bool  # true: use batch statistics and update the moving ones
end
|
||||
|
||||
# Build a GroupNorm layer for `chs` channels split into `G` groups.
# `initβ` / `initγ` are called with `chs` and their results are wrapped with `param`;
# the moving mean starts at zero and the moving variance at one, one entry per group.
# The layer is created in active (training) mode.
function GroupNorm(chs::Integer, G::Integer, λ = identity;
                   initβ = (i) -> zeros(Float32, i), initγ = (i) -> ones(Float32, i),
                   ϵ = 1f-5, momentum = 0.1f0)
    bias = param(initβ(chs))
    scale = param(initγ(chs))
    moving_mean = zeros(G, 1)
    moving_var = ones(G, 1)
    return GroupNorm(G, λ, bias, scale, moving_mean, moving_var, ϵ, momentum, true)
end
|
||||
|
||||
# Apply group normalization to `x`.
#
# `x` must have at least 3 dimensions; its second-to-last dimension is the channel
# dimension and its last dimension is the batch dimension. The channels are split into
# `gn.G` groups and every (group, sample) slice is normalised to zero mean and unit
# variance, then scaled by `gn.γ`, shifted by `gn.β` (both per channel) and passed
# through `gn.λ`.
#
# While `gn.active` is true the statistics are computed from `x` and the moving
# mean/variance stored in `gn` are updated; otherwise the stored moving statistics
# are used and `gn` is left untouched.
#
# Throws an `ErrorException` if `x` has fewer than 3 dimensions, if its channel count
# differs from `length(gn.β)`, or if `gn.G` does not divide the channel count.
function (gn::GroupNorm)(x)
    nd = ndims(x)
    # Check the dimensionality first: the channel checks below index dimension nd - 1.
    nd > 2 || error("Need to pass at least 3 dimensions for Group Norm to work")
    size(x, nd - 1) == length(gn.β) || error("Group Norm expected $(length(gn.β)) channels, but got $(size(x,ndims(x)-1)) channels")
    size(x, nd - 1) % gn.G == 0 || error("The number of groups ($(gn.G)) must divide the number of channels ($(size(x,ndims(x) -1)))")

    groups = gn.G
    channels = size(x, nd - 1)
    batches = size(x, nd)
    channels_per_group = div(channels, groups)
    og_shape = size(x)
    spatial = og_shape[1:end-2]

    # γ and β hold one value per channel: broadcast them along the channel dimension.
    affine_shape = ones(Int, nd)
    affine_shape[end-1] = channels
    γ = reshape(gn.γ, affine_shape...)
    β = reshape(gn.β, affine_shape...)

    # Number of elements that contribute to each (group, sample) statistic.
    m = prod(spatial) * channels_per_group

    # Split the channel dimension: (W, H, ..., C/G, G, N). Needed in both modes.
    y = reshape(x, (spatial..., channels_per_group, groups, batches))

    if !gn.active
        # Moving statistics hold one value per group; shape them as (1, ..., 1, G, 1)
        # so that they broadcast against `y`.
        stat_shape = (ntuple(_ -> 1, nd - 1)..., groups, 1)
        μ = reshape(gn.μ, stat_shape)
        σ² = reshape(gn.σ², stat_shape)
        ϵ = gn.ϵ
    else
        T = eltype(x)
        reduce_dims = collect(1:ndims(y)-2)  # everything except the group and batch dims
        μ = mean(y, dims = reduce_dims)
        σ² = sum((y .- μ) .^ 2, dims = reduce_dims) ./ m
        ϵ = data(convert(T, gn.ϵ))

        # Update the moving mean/variance. The per-sample statistics are averaged over
        # the batch so the stored arrays keep their (G, 1) shape whatever the batch size.
        mtm = data(convert(T, gn.momentum))
        batch_μ = mean(reshape(data(μ), (groups, batches)), dims = 2)
        batch_σ² = mean(reshape(data(σ²), (groups, batches)), dims = 2)
        # Unbiased-variance correction m / (m - 1); guarded so that m == 1 cannot
        # produce Inf (and then NaN) in the moving variance.
        correction = m / max(m - 1, 1)

        gn.μ = (1 - mtm) .* gn.μ .+ mtm .* batch_μ
        gn.σ² = (1 - mtm) .* gn.σ² .+ (mtm * correction) .* batch_σ²
    end

    let λ = gn.λ
        # Normalise per group, then restore the original layout before the affine step.
        x̂ = reshape((y .- μ) ./ sqrt.(σ² .+ ϵ), og_shape)
        return λ.(γ .* x̂ .+ β)
    end
end
|
||||
|
||||
# Fields of a GroupNorm layer exposed as its children (the group count `G` is not included).
function children(gn::GroupNorm)
    return (gn.λ, gn.β, gn.γ, gn.μ, gn.σ², gn.ϵ, gn.momentum, gn.active)
end
|
||||
|
||||
# Rebuild the layer with `f` applied to the array fields (β, γ, μ, σ²), e.g.
# mapchildren(cu, gn). The group count, activation, ϵ, momentum and the active flag
# are carried over unchanged.
mapchildren(f, gn::GroupNorm) =
    GroupNorm(gn.G, gn.λ, f(gn.β), f(gn.γ), f(gn.μ), f(gn.σ²), gn.ϵ, gn.momentum, gn.active)
|
||||
|
||||
# Switch the layer between test mode (`test == true`, layer inactive) and training mode.
function _testmode!(gn::GroupNorm, test)
    return gn.active = !test
end
|
||||
|
||||
# Print a compact description: the size of β, plus the activation when it is not `identity`.
function Base.show(io::IO, l::GroupNorm)
    print(io, "GroupNorm($(join(size(l.β), ", "))")
    if !(l.λ == identity)
        print(io, ", λ = $(l.λ)")
    end
    return print(io, ")")
end
|
Loading…
Reference in New Issue
Block a user