Made a few fixes. Added tests

This commit is contained in:
Shreyas 2019-03-28 00:51:31 +05:30
parent 8033dca0c3
commit 671aed963e
2 changed files with 118 additions and 6 deletions

View File

@ -342,26 +342,33 @@ function(gn::GroupNorm)(x)
# Output reshaped to (W,H...,C/G,G,N)
affine_shape[end-1] = channels
μ_affine_shape = ones(Int,dims + 1)
μ_affine_shape[end-1] = groups
m = prod(size(x)[1:end-2]) * channels_per_group
γ = reshape(gn.γ, affine_shape...)
β = reshape(gn.β, affine_shape...)
if !gn.active
μ = reshape(gn.μ, affine_shape...)
σ² = reshape(gn.σ², affine_shape...)
og_shape = size(x)
μ = reshape(gn.μ, μ_affine_shape...) # Shape : (1,1,...C/G,G,1)
σ² = reshape(gn.σ², μ_affine_shape...) # Shape : (1,1,...C/G,G,1)
ϵ = gn.ϵ
y = reshape(x,((size(x))[1:end-2]...,channels_per_group,groups,batches))
else
T = eltype(x)
og_shape = size(x)
y = reshape(x,((size(x))[1:end-2]...,channels_per_group,groups,batches))
axes = [(1:ndims(y)-2)...] # axes to reduce along (all but channels axis)
μ = mean(y, dims = axes)
σ² = sum((y .- μ) .^ 2, dims = axes) ./ m
σ² = mean((y .- μ) .^ 2, dims = axes)
ϵ = data(convert(T, gn.ϵ))
# update moving mean/std
mtm = data(convert(T, gn.momentum))
gn.μ = (1 - mtm) .* gn.μ .+ mtm .* reshape(data(μ), (groups,batches))
gn.σ² = (1 - mtm) .* gn.σ² .+ (mtm * m / (m - 1)) .* reshape(data(σ²), (groups,batches))
gn.μ = mean((1 - mtm) .* gn.μ .+ mtm .* reshape(data(μ), (groups,batches)),dims=2)
gn.σ² = mean((1 - mtm) .* gn.σ² .+ (mtm * m / (m - 1)) .* reshape(data(σ²), (groups,batches)),dims=2)
end
let λ = gn.λ

View File

@ -1,5 +1,5 @@
using Flux: testmode!
using Flux.Tracker: data
using Flux.Tracker: data
@testset "Dropout" begin
x = [1.,2.,3.]
@ -200,3 +200,108 @@ end
end
end
@testset "GroupNorm" begin
# begin tests
squeeze(x) = dropdims(x, dims = tuple(findall(size(x) .== 1)...)) # To remove all singular dimensions
let m = GroupNorm(4,2), sizes = (3,4,2),
x = param(reshape(collect(1:prod(sizes)), sizes))
@test m.β.data == [0, 0, 0, 0] # initβ(32)
@test m.γ.data == [1, 1, 1, 1] # initγ(32)
@test m.active
m(x)
#julia> x
#[:, :, 1] =
# 1.0 4.0 7.0 10.0
# 2.0 5.0 8.0 11.0
# 3.0 6.0 9.0 12.0
#
#[:, :, 2] =
# 13.0 16.0 19.0 22.0
# 14.0 17.0 20.0 23.0
# 15.0 18.0 21.0 24.0
#
# μ will be
# (1. + 2. + 3. + 4. + 5. + 6.) / 6 = 3.5
# (7. + 8. + 9. + 10. + 11. + 12.) / 6 = 9.5
#
# (13. + 14. + 15. + 16. + 17. + 18.) / 6 = 15.5
# (19. + 20. + 21. + 22. + 23. + 24.) / 6 = 21.5
#
# μ =
# 3.5 15.5
# 9.5 21.5
#
# ∴ update rule with momentum:
# (1. - .1) * 0 + .1 * (3.5 + 15.5) / 2 = 0.95
# (1. - .1) * 0 + .1 * (9.5 + 21.5) / 2 = 1.55
@test m.μ [0.95, 1.55]
# julia> mean(var(reshape(x,3,2,2,2),dims=(1,2)).* .1,dims=2) .+ .9*1.
# 2-element Array{Tracker.TrackedReal{Float64},1}:
# 1.25
# 1.25
@test m.σ² mean(squeeze(var(reshape(x,3,2,2,2),dims=(1,2))).*.1,dims=2) .+ .9*1.
testmode!(m)
@test !m.active
x = m(x).data
println(x[1])
@test isapprox(x[1], (1 - 0.95) / sqrt(1.25 + 1f-5), atol = 1.0e-5)
end
# with activation function
let m = GroupNorm(4,2, sigmoid), sizes = (3, 4, 2),
x = param(reshape(collect(1:prod(sizes)), sizes))
μ_affine_shape = ones(Int,length(sizes) + 1)
μ_affine_shape[end-1] = 2 # Number of groups
affine_shape = ones(Int,length(sizes) + 1)
affine_shape[end-2] = 2 # Channels per group
affine_shape[end-1] = 2 # Number of groups
affine_shape[1] = sizes[1]
affine_shape[end] = sizes[end]
og_shape = size(x)
@test m.active
m(x)
testmode!(m)
@test !m.active
y = m(x)
x_ = reshape(x,affine_shape...)
out = reshape(data(sigmoid.((x_ .- reshape(m.μ,μ_affine_shape...)) ./ sqrt.(reshape(m.σ²,μ_affine_shape...) .+ m.ϵ))),og_shape)
@test isapprox(y, out, atol = 1.0e-7)
end
let m = GroupNorm(2,2), sizes = (2, 4, 1, 2, 3),
x = param(reshape(collect(1:prod(sizes)), sizes))
y = reshape(permutedims(x, [3, 1, 2, 4, 5]), :, 2, 3)
y = reshape(m(y), sizes...)
@test m(x) == y
end
# check that μ, σ², and the output are the correct size for higher rank tensors
let m = GroupNorm(4,2), sizes = (5, 5, 3, 4, 4, 6),
x = param(reshape(collect(1:prod(sizes)), sizes))
y = m(x)
@test size(m.μ) == (m.G,1)
@test size(m.σ²) == (m.G,1)
@test size(y) == sizes
end
# show that group norm is the same as instance norm when the group size is the same as the number of channels
let IN = InstanceNorm(4), GN = GroupNorm(4,4), sizes = (2,2,3,4,5),
x = param(reshape(collect(1:prod(sizes)), sizes))
@test IN(x) GN(x)
end
end