Made Requested Changes
This commit is contained in:
parent
671aed963e
commit
61c1fbd013
@ -289,17 +289,17 @@ end
|
|||||||
|
|
||||||
"""
|
"""
|
||||||
Group Normalization.
|
Group Normalization.
|
||||||
Known to improve the overall accuracy in case of classification and segmentation tasks.
|
This layer can outperform Batch-Normalization and Instance-Normalization.
|
||||||
|
|
||||||
GroupNorm(chs::Integer, G::Integer, λ = identity;
|
GroupNorm(chs::Integer, G::Integer, λ = identity;
|
||||||
initβ = (i) -> zeros(Float32, i), initγ = (i) -> ones(Float32, i),
|
initβ = (i) -> zeros(Float32, i), initγ = (i) -> ones(Float32, i),
|
||||||
ϵ = 1f-5, momentum = 0.1f0)
|
ϵ = 1f-5, momentum = 0.1f0)
|
||||||
|
|
||||||
chs is the numebr of channels, the channeld dimension of your input.
|
chs is the number of channels, the channel dimension of your input.
|
||||||
For an array of N dimensions, the (N-1)th index is the channel dimension.
|
For an array of N dimensions, the (N-1)th index is the channel dimension.
|
||||||
|
|
||||||
G is the number of groups along which the statistics would be computed.
|
G is the number of groups along which the statistics would be computed.
|
||||||
The number of groups must divide the number of channels for this to work.
|
The number of channels must be an integer multiple of the number of groups.
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
```
|
```
|
||||||
@ -349,16 +349,15 @@ function(gn::GroupNorm)(x)
|
|||||||
γ = reshape(gn.γ, affine_shape...)
|
γ = reshape(gn.γ, affine_shape...)
|
||||||
β = reshape(gn.β, affine_shape...)
|
β = reshape(gn.β, affine_shape...)
|
||||||
|
|
||||||
|
y = reshape(x,((size(x))[1:end-2]...,channels_per_group,groups,batches))
|
||||||
if !gn.active
|
if !gn.active
|
||||||
og_shape = size(x)
|
og_shape = size(x)
|
||||||
μ = reshape(gn.μ, μ_affine_shape...) # Shape : (1,1,...C/G,G,1)
|
μ = reshape(gn.μ, μ_affine_shape...) # Shape : (1,1,...C/G,G,1)
|
||||||
σ² = reshape(gn.σ², μ_affine_shape...) # Shape : (1,1,...C/G,G,1)
|
σ² = reshape(gn.σ², μ_affine_shape...) # Shape : (1,1,...C/G,G,1)
|
||||||
ϵ = gn.ϵ
|
ϵ = gn.ϵ
|
||||||
y = reshape(x,((size(x))[1:end-2]...,channels_per_group,groups,batches))
|
|
||||||
else
|
else
|
||||||
T = eltype(x)
|
T = eltype(x)
|
||||||
og_shape = size(x)
|
og_shape = size(x)
|
||||||
y = reshape(x,((size(x))[1:end-2]...,channels_per_group,groups,batches))
|
|
||||||
axes = [(1:ndims(y)-2)...] # axes to reduce along (all but channels axis)
|
axes = [(1:ndims(y)-2)...] # axes to reduce along (all but channels axis)
|
||||||
μ = mean(y, dims = axes)
|
μ = mean(y, dims = axes)
|
||||||
σ² = mean((y .- μ) .^ 2, dims = axes)
|
σ² = mean((y .- μ) .^ 2, dims = axes)
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
using Flux: testmode!
|
using Flux: testmode!
|
||||||
using Flux.Tracker: data
|
using Flux.Tracker: data
|
||||||
|
|
||||||
@testset "Dropout" begin
|
@testset "Dropout" begin
|
||||||
x = [1.,2.,3.]
|
x = [1.,2.,3.]
|
||||||
@ -304,4 +304,10 @@ end
|
|||||||
@test IN(x) ≈ GN(x)
|
@test IN(x) ≈ GN(x)
|
||||||
end
|
end
|
||||||
|
|
||||||
|
# show that group norm is the same as batch norm for a group of size 1 and batch of size 1
|
||||||
|
let BN = BatchNorm(4), GN = GroupNorm(4,1), sizes = (2,2,3,4,1),
|
||||||
|
x = param(reshape(collect(1:prod(sizes)), sizes))
|
||||||
|
@test BN(x) ≈ GN(x)
|
||||||
|
end
|
||||||
|
|
||||||
end
|
end
|
||||||
|
Loading…
Reference in New Issue
Block a user