Merge pull request #696 from shreyas-kowshik/group_norm_patch
Added GroupNorm Layer
commit 7418a2d7d7
@@ -6,10 +6,8 @@ using Base: tail
using MacroTools, Juno, Requires, Reexport, Statistics, Random
using MacroTools: @forward

-export Chain, Dense, Maxout,
-       RNN, LSTM, GRU,
-       Conv, ConvTranspose, MaxPool, MeanPool, DepthwiseConv,
-       Dropout, AlphaDropout, LayerNorm, BatchNorm, InstanceNorm,
+export Chain, Dense, Maxout, RNN, LSTM, GRU, Conv, ConvTranspose, MaxPool, MeanPool,
+       DepthwiseConv, Dropout, AlphaDropout, LayerNorm, BatchNorm, InstanceNorm, GroupNorm,
        params, mapleaves, cpu, gpu, f32, f64

@reexport using NNlib
@@ -286,3 +286,109 @@ function Base.show(io::IO, l::InstanceNorm)
  (l.λ == identity) || print(io, ", λ = $(l.λ)")
  print(io, ")")
end

"""
Group Normalization.
This layer can outperform Batch Normalization and Instance Normalization.

    GroupNorm(chs::Integer, G::Integer, λ = identity;
              initβ = (i) -> zeros(Float32, i), initγ = (i) -> ones(Float32, i),
              ϵ = 1f-5, momentum = 0.1f0)

``chs`` is the number of channels, the channel dimension of your input.
For an array of N dimensions, the (N-1)th index is the channel dimension.

``G`` is the number of groups along which the statistics are computed.
The number of channels must be an integer multiple of the number of groups.

Example:
```
m = Chain(Conv((3,3), 1=>32, leakyrelu; pad = 1),
          GroupNorm(32, 16)) # 32 channels, 16 groups (G = 16), thus 2 channels per group
```

Link: https://arxiv.org/pdf/1803.08494.pdf
"""

mutable struct GroupNorm{F,V,W,N,T}
  G::T  # number of groups
  λ::F  # activation function
  β::V  # bias
  γ::V  # scale
  μ::W  # moving mean
  σ²::W # moving variance
  ϵ::N
  momentum::N
  active::Bool
end

GroupNorm(chs::Integer, G::Integer, λ = identity;
          initβ = (i) -> zeros(Float32, i), initγ = (i) -> ones(Float32, i), ϵ = 1f-5, momentum = 0.1f0) =
  GroupNorm(G, λ, param(initβ(chs)), param(initγ(chs)),
            zeros(G,1), ones(G,1), ϵ, momentum, true)

function (gn::GroupNorm)(x)
  size(x, ndims(x)-1) == length(gn.β) || error("GroupNorm expected $(length(gn.β)) channels, but got $(size(x, ndims(x)-1)) channels")
  ndims(x) > 2 || error("Need an input with at least 3 dimensions for GroupNorm to work")
  size(x, ndims(x)-1) % gn.G == 0 || error("The number of groups ($(gn.G)) must divide the number of channels ($(size(x, ndims(x)-1)))")

  dims = length(size(x))
  groups = gn.G
  channels = size(x, dims-1)
  batches = size(x, dims)
  channels_per_group = div(channels, groups)
  affine_shape = ones(Int, dims)

  # Output reshaped to (W, H, ..., C/G, G, N)
  affine_shape[end-1] = channels

  μ_affine_shape = ones(Int, dims + 1)
  μ_affine_shape[end-1] = groups

  m = prod(size(x)[1:end-2]) * channels_per_group
  γ = reshape(gn.γ, affine_shape...)
  β = reshape(gn.β, affine_shape...)

  y = reshape(x, (size(x)[1:end-2]..., channels_per_group, groups, batches))
  if !gn.active
    og_shape = size(x)
    μ = reshape(gn.μ, μ_affine_shape...)   # Shape : (1, 1, ..., 1, G, 1)
    σ² = reshape(gn.σ², μ_affine_shape...) # Shape : (1, 1, ..., 1, G, 1)
    ϵ = gn.ϵ
  else
    T = eltype(x)
    og_shape = size(x)
    axes = [(1:ndims(y)-2)...] # axes to reduce along (all but the group and batch axes)
    μ = mean(y, dims = axes)
    σ² = mean((y .- μ) .^ 2, dims = axes)

    ϵ = data(convert(T, gn.ϵ))
    # update moving mean/variance
    mtm = data(convert(T, gn.momentum))

    gn.μ = mean((1 - mtm) .* gn.μ .+ mtm .* reshape(data(μ), (groups, batches)), dims = 2)
    gn.σ² = mean((1 - mtm) .* gn.σ² .+ (mtm * m / (m - 1)) .* reshape(data(σ²), (groups, batches)), dims = 2)
  end

  let λ = gn.λ
    x̂ = (y .- μ) ./ sqrt.(σ² .+ ϵ)

    # Reshape x̂ back to the original input shape
    x̂ = reshape(x̂, og_shape)
    λ.(γ .* x̂ .+ β)
  end
end
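Not part of the patch, just to make the grouping in the forward pass concrete: a standalone sketch of how per-group statistics are obtained by folding the channel dimension into `(C÷G, G)`. The 4-D input shape, G = 3, and the use of the `Statistics` stdlib are illustrative assumptions.

```julia
using Statistics

x = randn(Float32, 4, 4, 6, 2)   # (W, H, C, N) with C = 6 channels, batch of 2
G = 3                            # 3 groups of C ÷ G = 2 channels each
W, H, C, N = size(x)

y = reshape(x, W, H, C ÷ G, G, N)  # (W, H, C/G, G, N)

# Reduce over every axis except the group and batch axes: one statistic per (group, batch).
μ  = mean(y, dims = (1, 2, 3))             # size (1, 1, 1, G, N)
σ² = mean((y .- μ) .^ 2, dims = (1, 2, 3))

x̂ = reshape((y .- μ) ./ sqrt.(σ² .+ 1f-5), size(x))  # normalized, back to (W, H, C, N)
```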

children(gn::GroupNorm) =
  (gn.λ, gn.β, gn.γ, gn.μ, gn.σ², gn.ϵ, gn.momentum, gn.active)

mapchildren(f, gn::GroupNorm) =  # e.g. mapchildren(cu, GN)
  GroupNorm(gn.G, gn.λ, f(gn.β), f(gn.γ), f(gn.μ), f(gn.σ²), gn.ϵ, gn.momentum, gn.active)

_testmode!(gn::GroupNorm, test) = (gn.active = !test)

function Base.show(io::IO, l::GroupNorm)
  print(io, "GroupNorm($(join(size(l.β), ", "))")
  (l.λ == identity) || print(io, ", λ = $(l.λ)")
  print(io, ")")
end
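For orientation (not part of the patch): a minimal usage sketch of the new layer. Shapes are illustrative; `testmode!` is the existing Flux helper that flips the `active` flag through `_testmode!` above.

```julia
using Flux

# 32 channels split into 16 groups of 2 channels each (as in the docstring example).
m = Chain(Conv((3,3), 1=>32, leakyrelu; pad = 1),
          GroupNorm(32, 16))

x = randn(Float32, 28, 28, 1, 4)  # (W, H, C, N) input, batch of 4

y = m(x)            # training mode: statistics come from this batch,
                    # and the moving mean/variance are updated

Flux.testmode!(m)   # inference mode: the stored moving statistics are used
ŷ = m(x)
```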
@@ -200,3 +200,114 @@ end
  end
end

@testset "GroupNorm" begin
  # begin tests
  squeeze(x) = dropdims(x, dims = tuple(findall(size(x) .== 1)...)) # To remove all singular dimensions

  let m = GroupNorm(4,2), sizes = (3,4,2),
      x = param(reshape(collect(1:prod(sizes)), sizes))

    @test m.β.data == [0, 0, 0, 0]  # initβ(4)
    @test m.γ.data == [1, 1, 1, 1]  # initγ(4)

    @test m.active

    m(x)

    #julia> x
    #[:, :, 1] =
    # 1.0  4.0  7.0  10.0
    # 2.0  5.0  8.0  11.0
    # 3.0  6.0  9.0  12.0
    #
    #[:, :, 2] =
    # 13.0  16.0  19.0  22.0
    # 14.0  17.0  20.0  23.0
    # 15.0  18.0  21.0  24.0
    #
    # μ will be
    # (1. + 2. + 3. + 4. + 5. + 6.) / 6 = 3.5
    # (7. + 8. + 9. + 10. + 11. + 12.) / 6 = 9.5
    #
    # (13. + 14. + 15. + 16. + 17. + 18.) / 6 = 15.5
    # (19. + 20. + 21. + 22. + 23. + 24.) / 6 = 21.5
    #
    # μ =
    # 3.5   15.5
    # 9.5   21.5
    #
    # ∴ update rule with momentum:
    # (1. - .1) * 0 + .1 * (3.5 + 15.5) / 2 = 0.95
    # (1. - .1) * 0 + .1 * (9.5 + 21.5) / 2 = 1.55
    @test m.μ ≈ [0.95, 1.55]

    # julia> mean(var(reshape(x,3,2,2,2),dims=(1,2)).* .1,dims=2) .+ .9*1.
    # 2-element Array{Tracker.TrackedReal{Float64},1}:
    #  1.25
    #  1.25
    @test m.σ² ≈ mean(squeeze(var(reshape(x,3,2,2,2),dims=(1,2))).*.1, dims=2) .+ .9*1.

    testmode!(m)
    @test !m.active

    x′ = m(x).data
    @test isapprox(x′[1], (1 - 0.95) / sqrt(1.25 + 1f-5), atol = 1.0e-5)
  end

  # with activation function
  let m = GroupNorm(4,2, sigmoid), sizes = (3, 4, 2),
      x = param(reshape(collect(1:prod(sizes)), sizes))

    μ_affine_shape = ones(Int, length(sizes) + 1)
    μ_affine_shape[end-1] = 2 # Number of groups

    affine_shape = ones(Int, length(sizes) + 1)
    affine_shape[end-2] = 2 # Channels per group
    affine_shape[end-1] = 2 # Number of groups
    affine_shape[1] = sizes[1]
    affine_shape[end] = sizes[end]

    og_shape = size(x)

    @test m.active
    m(x)

    testmode!(m)
    @test !m.active

    y = m(x)
    x_ = reshape(x, affine_shape...)
    out = reshape(data(sigmoid.((x_ .- reshape(m.μ, μ_affine_shape...)) ./ sqrt.(reshape(m.σ², μ_affine_shape...) .+ m.ϵ))), og_shape)
    @test isapprox(y, out, atol = 1.0e-7)
  end

  let m = GroupNorm(2,2), sizes = (2, 4, 1, 2, 3),
      x = param(reshape(collect(1:prod(sizes)), sizes))
    y = reshape(permutedims(x, [3, 1, 2, 4, 5]), :, 2, 3)
    y = reshape(m(y), sizes...)
    @test m(x) == y
  end

  # check that μ, σ², and the output are the correct size for higher rank tensors
  let m = GroupNorm(4,2), sizes = (5, 5, 3, 4, 4, 6),
      x = param(reshape(collect(1:prod(sizes)), sizes))
    y = m(x)
    @test size(m.μ) == (m.G, 1)
    @test size(m.σ²) == (m.G, 1)
    @test size(y) == sizes
  end

  # show that group norm is the same as instance norm when the group size is the same as the number of channels
  let IN = InstanceNorm(4), GN = GroupNorm(4,4), sizes = (2,2,3,4,5),
      x = param(reshape(collect(1:prod(sizes)), sizes))
    @test IN(x) ≈ GN(x)
  end

  # show that group norm is the same as batch norm for a group of size 1 and batch of size 1
  let BN = BatchNorm(4), GN = GroupNorm(4,4), sizes = (2,2,3,4,1),
      x = param(reshape(collect(1:prod(sizes)), sizes))
    @test BN(x) ≈ GN(x)
  end

end