Made Requested Changes

This commit is contained in:
Shreyas 2019-03-28 01:33:04 +05:30
parent 671aed963e
commit 61c1fbd013
2 changed files with 11 additions and 6 deletions

View File

@ -289,17 +289,17 @@ end
""" """
Group Normalization. Group Normalization.
Known to improve the overall accuracy in case of classification and segmentation tasks. This layer can outperform Batch-Normalization and Instance-Normalization.
GroupNorm(chs::Integer, G::Integer, λ = identity; GroupNorm(chs::Integer, G::Integer, λ = identity;
initβ = (i) -> zeros(Float32, i), initγ = (i) -> ones(Float32, i), initβ = (i) -> zeros(Float32, i), initγ = (i) -> ones(Float32, i),
ϵ = 1f-5, momentum = 0.1f0) ϵ = 1f-5, momentum = 0.1f0)
chs is the numebr of channels, the channeld dimension of your input. chs is the number of channels, the channel dimension of your input.
For an array of N dimensions, the (N-1)th index is the channel dimension. For an array of N dimensions, the (N-1)th index is the channel dimension.
G is the number of groups along which the statistics would be computed. G is the number of groups along which the statistics would be computed.
The number of groups must divide the number of channels for this to work. The number of channels must be an integer multiple of the number of groups.
Example: Example:
``` ```
@ -349,16 +349,15 @@ function(gn::GroupNorm)(x)
γ = reshape(gn.γ, affine_shape...) γ = reshape(gn.γ, affine_shape...)
β = reshape(gn.β, affine_shape...) β = reshape(gn.β, affine_shape...)
y = reshape(x,((size(x))[1:end-2]...,channels_per_group,groups,batches))
if !gn.active if !gn.active
og_shape = size(x) og_shape = size(x)
μ = reshape(gn.μ, μ_affine_shape...) # Shape : (1,1,...C/G,G,1) μ = reshape(gn.μ, μ_affine_shape...) # Shape : (1,1,...C/G,G,1)
σ² = reshape(gn.σ², μ_affine_shape...) # Shape : (1,1,...C/G,G,1) σ² = reshape(gn.σ², μ_affine_shape...) # Shape : (1,1,...C/G,G,1)
ϵ = gn.ϵ ϵ = gn.ϵ
y = reshape(x,((size(x))[1:end-2]...,channels_per_group,groups,batches))
else else
T = eltype(x) T = eltype(x)
og_shape = size(x) og_shape = size(x)
y = reshape(x,((size(x))[1:end-2]...,channels_per_group,groups,batches))
axes = [(1:ndims(y)-2)...] # axes to reduce along (all but channels axis) axes = [(1:ndims(y)-2)...] # axes to reduce along (all but channels axis)
μ = mean(y, dims = axes) μ = mean(y, dims = axes)
σ² = mean((y .- μ) .^ 2, dims = axes) σ² = mean((y .- μ) .^ 2, dims = axes)

View File

@ -1,5 +1,5 @@
using Flux: testmode! using Flux: testmode!
using Flux.Tracker: data using Flux.Tracker: data
@testset "Dropout" begin @testset "Dropout" begin
x = [1.,2.,3.] x = [1.,2.,3.]
@ -304,4 +304,10 @@ end
@test IN(x) GN(x) @test IN(x) GN(x)
end end
# show that group norm is the same as batch norm for a group of size 1 and batch of size 1
let BN = BatchNorm(4), GN = GroupNorm(4,1), sizes = (2,2,3,4,1),
x = param(reshape(collect(1:prod(sizes)), sizes))
@test BN(x) GN(x)
end
end end