Made Requested Changes

This commit is contained in:
Shreyas 2019-03-26 21:42:49 +05:30
parent 35431e3da9
commit 595f1cf6eb
2 changed files with 102 additions and 4 deletions

View File

@ -6,10 +6,8 @@ using Base: tail
using MacroTools, Juno, Requires, Reexport, Statistics, Random
using MacroTools: @forward
export Chain, Dense, Maxout,
RNN, LSTM, GRU,
Conv, ConvTranspose, MaxPool, MeanPool, DepthwiseConv,
Dropout, AlphaDropout, LayerNorm, BatchNorm, InstanceNorm,
export Chain, Dense, RNN, LSTM, GRU, Conv, ConvTranspose, MaxPool, MeanPool,
DepthwiseConv, Dropout, AlphaDropout, LayerNorm, BatchNorm, InstanceNorm, GroupNorm,
params, mapleaves, cpu, gpu, f32, f64
@reexport using NNlib

View File

@ -286,3 +286,103 @@ function Base.show(io::IO, l::InstanceNorm)
(l.λ == identity) || print(io, ", λ = $(l.λ)")
print(io, ")")
end
"""
Group Normalization.
Known to improve the overall accuracy in case of classification and segmentation tasks.
GroupNorm(chs::Integer, G::Integer, λ = identity;
initβ = (i) -> zeros(Float32, i), initγ = (i) -> ones(Float32, i),
ϵ = 1f-5, momentum = 0.1f0)
chs is the numebr of channels, the channeld dimension of your input.
For an array of N dimensions, the (N-1)th index is the channel dimension.
G is the number of groups along which the statistics would be computed.
The number of groups must divide the number of channels for this to work.
Example:
```
m = Chain(Conv((3,3), 1=>32, leakyrelu;pad = 1),
GroupNorm(32,16)) # 32 channels, 16 groups (G = 16), thus 2 channels per group used
```
Link : https://arxiv.org/pdf/1803.08494.pdf
"""
mutable struct GroupNorm{F,V,W,N,T}
G::T # number of groups
λ::F # activation function
β::V # bias
γ::V # scale
μ::W # moving mean
σ²::W # moving std
ϵ::N
momentum::N
active::Bool
end
GroupNorm(chs::Integer, G::Integer, λ = identity;
initβ = (i) -> zeros(Float32, i), initγ = (i) -> ones(Float32, i), ϵ = 1f-5, momentum = 0.1f0) =
GroupNorm(G, λ, param(initβ(chs)), param(initγ(chs)),
zeros(G,1), ones(G,1), ϵ, momentum, true)
function(gn::GroupNorm)(x)
size(x,ndims(x)-1) == length(gn.β) || error("Group Norm expected $(length(gn.β)) channels, but got $(size(x,ndims(x)-1)) channels")
ndims(x) > 2 || error("Need to pass at least 3 channels for Group Norm to work")
(size(x,ndims(x) -1))%gn.G == 0 || error("The number of groups ($(gn.G)) must divide the number of channels ($(size(x,ndims(x) -1)))")
dims = length(size(x))
groups = gn.G
channels = size(x, dims-1)
batches = size(x,dims)
channels_per_group = div(channels,groups)
affine_shape = ones(Int, dims)
# Output reshaped to (W,H...,C/G,G,N)
affine_shape[end-1] = channels
m = prod(size(x)[1:end-2]) * channels_per_group
γ = reshape(gn.γ, affine_shape...)
β = reshape(gn.β, affine_shape...)
if !gn.active
μ = reshape(gn.μ, affine_shape...)
σ² = reshape(gn.σ², affine_shape...)
ϵ = gn.ϵ
else
T = eltype(x)
og_shape = size(x)
y = reshape(x,((size(x))[1:end-2]...,channels_per_group,groups,batches))
axes = [(1:ndims(y)-2)...] # axes to reduce along (all but channels axis)
μ = mean(y, dims = axes)
σ² = sum((y .- μ) .^ 2, dims = axes) ./ m
ϵ = data(convert(T, gn.ϵ))
# update moving mean/std
mtm = data(convert(T, gn.momentum))
gn.μ = (1 - mtm) .* gn.μ .+ mtm .* reshape(data(μ), (groups,batches))
gn.σ² = (1 - mtm) .* gn.σ² .+ (mtm * m / (m - 1)) .* reshape(data(σ²), (groups,batches))
end
let λ = gn.λ
= (y .- μ) ./ sqrt.(σ² .+ ϵ)
# Reshape x̂
= reshape(,og_shape)
λ.(γ .* .+ β)
end
end
children(gn::GroupNorm) =
(gn.λ, gn.β, gn.γ, gn.μ, gn.σ², gn.ϵ, gn.momentum, gn.active)
mapchildren(f, gn::GroupNorm) = # e.g. mapchildren(cu, BN)
GroupNorm(gn,G,gn.λ, f(gn.β), f(gn.γ), f(gn.μ), f(gn.σ²), gn.ϵ, gn.momentum, gn.active)
_testmode!(gn::GroupNorm, test) = (gn.active = !test)
function Base.show(io::IO, l::GroupNorm)
print(io, "GroupNorm($(join(size(l.β), ", "))")
(l.λ == identity) || print(io, ", λ = $(l.λ)")
print(io, ")")
end