Made Requested Changes
This commit is contained in:
parent
671aed963e
commit
61c1fbd013
|
@ -289,17 +289,17 @@ end
|
|||
|
||||
"""
|
||||
Group Normalization.
|
||||
Known to improve the overall accuracy in case of classification and segmentation tasks.
|
||||
This layer can outperform Batch-Normalization and Instance-Normalization.
|
||||
|
||||
GroupNorm(chs::Integer, G::Integer, λ = identity;
|
||||
initβ = (i) -> zeros(Float32, i), initγ = (i) -> ones(Float32, i),
|
||||
ϵ = 1f-5, momentum = 0.1f0)
|
||||
|
||||
chs is the numebr of channels, the channeld dimension of your input.
|
||||
chs is the number of channels, the channel dimension of your input.
|
||||
For an array of N dimensions, the (N-1)th index is the channel dimension.
|
||||
|
||||
G is the number of groups along which the statistics would be computed.
|
||||
The number of groups must divide the number of channels for this to work.
|
||||
The number of channels must be an integer multiple of the number of groups.
|
||||
|
||||
Example:
|
||||
```
|
||||
|
@ -349,16 +349,15 @@ function(gn::GroupNorm)(x)
|
|||
γ = reshape(gn.γ, affine_shape...)
|
||||
β = reshape(gn.β, affine_shape...)
|
||||
|
||||
y = reshape(x,((size(x))[1:end-2]...,channels_per_group,groups,batches))
|
||||
if !gn.active
|
||||
og_shape = size(x)
|
||||
μ = reshape(gn.μ, μ_affine_shape...) # Shape : (1,1,...C/G,G,1)
|
||||
σ² = reshape(gn.σ², μ_affine_shape...) # Shape : (1,1,...C/G,G,1)
|
||||
ϵ = gn.ϵ
|
||||
y = reshape(x,((size(x))[1:end-2]...,channels_per_group,groups,batches))
|
||||
else
|
||||
T = eltype(x)
|
||||
og_shape = size(x)
|
||||
y = reshape(x,((size(x))[1:end-2]...,channels_per_group,groups,batches))
|
||||
axes = [(1:ndims(y)-2)...] # axes to reduce along (all but channels axis)
|
||||
μ = mean(y, dims = axes)
|
||||
σ² = mean((y .- μ) .^ 2, dims = axes)
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
using Flux: testmode!
|
||||
using Flux.Tracker: data
|
||||
using Flux.Tracker: data
|
||||
|
||||
@testset "Dropout" begin
|
||||
x = [1.,2.,3.]
|
||||
|
@ -304,4 +304,10 @@ end
|
|||
@test IN(x) ≈ GN(x)
|
||||
end
|
||||
|
||||
# show that group norm is the same as batch norm for a group of size 1 and batch of size 1
|
||||
let BN = BatchNorm(4), GN = GroupNorm(4,1), sizes = (2,2,3,4,1),
|
||||
x = param(reshape(collect(1:prod(sizes)), sizes))
|
||||
@test BN(x) ≈ GN(x)
|
||||
end
|
||||
|
||||
end
|
||||
|
|
Loading…
Reference in New Issue