GroupNorm made to use istraining()

This commit is contained in:
thebhatman 2019-06-11 22:04:33 +05:30
parent a56cfb73c3
commit 11073dcd25

View File

@ -264,11 +264,11 @@ function Base.show(io::IO, l::InstanceNorm)
end
"""
Group Normalization.
Group Normalization.
This layer can outperform Batch-Normalization and Instance-Normalization.
GroupNorm(chs::Integer, G::Integer, λ = identity;
initβ = (i) -> zeros(Float32, i), initγ = (i) -> ones(Float32, i),
initβ = (i) -> zeros(Float32, i), initγ = (i) -> ones(Float32, i),
ϵ = 1f-5, momentum = 0.1f0)
``chs`` is the number of channels, the channel dimension of your input.
@ -280,7 +280,7 @@ The number of channels must be an integer multiple of the number of groups.
Example:
```
m = Chain(Conv((3,3), 1=>32, leakyrelu;pad = 1),
GroupNorm(32,16)) # 32 channels, 16 groups (G = 16), thus 2 channels per group used
GroupNorm(32,16)) # 32 channels, 16 groups (G = 16), thus 2 channels per group used
```
Link : https://arxiv.org/pdf/1803.08494.pdf
@ -295,7 +295,6 @@ mutable struct GroupNorm{F,V,W,N,T}
σ²::W # moving variance
ϵ::N
momentum::N
active::Bool
end
GroupNorm(chs::Integer, G::Integer, λ = identity;
@ -324,9 +323,9 @@ function(gn::GroupNorm)(x)
m = prod(size(x)[1:end-2]) * channels_per_group
γ = reshape(gn.γ, affine_shape...)
β = reshape(gn.β, affine_shape...)
y = reshape(x,((size(x))[1:end-2]...,channels_per_group,groups,batches))
if !gn.active
if !istraining()
og_shape = size(x)
μ = reshape(gn.μ, μ_affine_shape...) # Shape : (1,1,...C/G,G,1)
σ² = reshape(gn.σ², μ_affine_shape...) # Shape : (1,1,...C/G,G,1)
@ -337,7 +336,7 @@ function(gn::GroupNorm)(x)
axes = [(1:ndims(y)-2)...] # axes to reduce along (all but channels axis)
μ = mean(y, dims = axes)
σ² = mean((y .- μ) .^ 2, dims = axes)
ϵ = data(convert(T, gn.ϵ))
# update moving mean/variance
mtm = data(convert(T, gn.momentum))
@ -349,7 +348,7 @@ function(gn::GroupNorm)(x)
let λ = gn.λ
x̂ = (y .- μ) ./ sqrt.(σ² .+ ϵ)
# Reshape x̂
# Reshape x̂
x̂ = reshape(x̂,og_shape)
λ.(γ .* x̂ .+ β)
end