Made Requested Changes

This commit is contained in:
Shreyas 2019-03-28 01:33:04 +05:30
parent 671aed963e
commit 61c1fbd013
2 changed files with 11 additions and 6 deletions

View File

@ -289,17 +289,17 @@ end
"""
Group Normalization.
Known to improve the overall accuracy in case of classification and segmentation tasks.
This layer can outperform Batch-Normalization and Instance-Normalization.
GroupNorm(chs::Integer, G::Integer, λ = identity;
initβ = (i) -> zeros(Float32, i), initγ = (i) -> ones(Float32, i),
ϵ = 1f-5, momentum = 0.1f0)
chs is the numebr of channels, the channeld dimension of your input.
chs is the number of channels, the channel dimension of your input.
For an array of N dimensions, the (N-1)th index is the channel dimension.
G is the number of groups along which the statistics would be computed.
The number of groups must divide the number of channels for this to work.
The number of channels must be an integer multiple of the number of groups.
Example:
```
@ -349,16 +349,15 @@ function(gn::GroupNorm)(x)
γ = reshape(gn.γ, affine_shape...)
β = reshape(gn.β, affine_shape...)
y = reshape(x,((size(x))[1:end-2]...,channels_per_group,groups,batches))
if !gn.active
og_shape = size(x)
μ = reshape(gn.μ, μ_affine_shape...) # Shape : (1,1,...C/G,G,1)
σ² = reshape(gn.σ², μ_affine_shape...) # Shape : (1,1,...C/G,G,1)
ϵ = gn.ϵ
y = reshape(x,((size(x))[1:end-2]...,channels_per_group,groups,batches))
else
T = eltype(x)
og_shape = size(x)
y = reshape(x,((size(x))[1:end-2]...,channels_per_group,groups,batches))
axes = [(1:ndims(y)-2)...] # axes to reduce along (all but channels axis)
μ = mean(y, dims = axes)
σ² = mean((y .- μ) .^ 2, dims = axes)

View File

@ -1,5 +1,5 @@
using Flux: testmode!
using Flux.Tracker: data
using Flux.Tracker: data
@testset "Dropout" begin
x = [1.,2.,3.]
@ -304,4 +304,10 @@ end
@test IN(x) GN(x)
end
# show that group norm is the same as batch norm for a group of size 1 and batch of size 1
let BN = BatchNorm(4), GN = GroupNorm(4,1), sizes = (2,2,3,4,1),
x = param(reshape(collect(1:prod(sizes)), sizes))
@test BN(x) GN(x)
end
end