normalise test fixes
This commit is contained in:
parent
877415be10
commit
250aef5a5a
|
@ -1,7 +1,8 @@
|
|||
using Flux, Test
|
||||
using Flux, Test, Statistics
|
||||
using Zygote: forward
|
||||
|
||||
trainmode(f, x...) = forward(f, x...)[1]
|
||||
trainmode(f) = (x...) -> trainmode(f, x...)
|
||||
|
||||
@testset "Dropout" begin
|
||||
x = [1.,2.,3.]
|
||||
|
@ -75,24 +76,23 @@ end
|
|||
# with activation function
|
||||
let m = BatchNorm(2, sigmoid), x = [1.0 3.0 5.0;
|
||||
2.0 4.0 6.0]
|
||||
y = trainmode(m, x)
|
||||
y = m(x)
|
||||
@test isapprox(y, sigmoid.((x .- m.μ) ./ sqrt.(m.σ² .+ m.ϵ)), atol = 1.0e-7)
|
||||
end
|
||||
|
||||
let m = BatchNorm(2), x = reshape(1:6, 3, 2, 1)
|
||||
let m = trainmode(BatchNorm(2)), x = reshape(Float32.(1:6), 3, 2, 1)
|
||||
y = reshape(permutedims(x, [2, 1, 3]), 2, :)
|
||||
y = permutedims(reshape(m(y), 2, 3, 1), [2, 1, 3])
|
||||
@test m(x) == y
|
||||
end
|
||||
|
||||
let m = BatchNorm(2), x = reshape(1:12, 2, 3, 2, 1)
|
||||
let m = trainmode(BatchNorm(2)), x = reshape(Float32.(1:12), 2, 3, 2, 1)
|
||||
y = reshape(permutedims(x, [3, 1, 2, 4]), 2, :)
|
||||
y = permutedims(reshape(m(y), 2, 2, 3, 1), [2, 3, 1, 4])
|
||||
@test m(x) == y
|
||||
end
|
||||
|
||||
let m = BatchNorm(2), x = reshape(1:24, 2, 2, 3, 2, 1)
|
||||
let m = trainmode(BatchNorm(2)), x = reshape(Float32.(1:24), 2, 2, 3, 2, 1)
|
||||
y = reshape(permutedims(x, [4, 1, 2, 3, 5]), 2, :)
|
||||
y = permutedims(reshape(m(y), 2, 2, 2, 3, 1), [2, 3, 4, 1, 5])
|
||||
@test m(x) == y
|
||||
|
@ -154,13 +154,12 @@ end
|
|||
affine_shape = collect(sizes)
|
||||
affine_shape[1] = 1
|
||||
|
||||
y = trainmode(m, x)
|
||||
y = m(x)
|
||||
@test isapprox(y, sigmoid.((x .- expand_inst(m.μ, affine_shape)) ./ sqrt.(expand_inst(m.σ², affine_shape) .+ m.ϵ)), atol = 1.0e-7)
|
||||
end
|
||||
|
||||
let m = InstanceNorm(2), sizes = (2, 4, 1, 2, 3),
|
||||
x = reshape(collect(1:prod(sizes)), sizes)
|
||||
let m = trainmode(InstanceNorm(2)), sizes = (2, 4, 1, 2, 3),
|
||||
x = Float32.(reshape(collect(1:prod(sizes)), sizes))
|
||||
y = reshape(permutedims(x, [3, 1, 2, 4, 5]), :, 2, 3)
|
||||
y = reshape(m(y), sizes...)
|
||||
@test m(x) == y
|
||||
|
@ -168,16 +167,16 @@ end
|
|||
|
||||
# check that μ, σ², and the output are the correct size for higher rank tensors
|
||||
let m = InstanceNorm(2), sizes = (5, 5, 3, 4, 2, 6),
|
||||
x = reshape(collect(1:prod(sizes)), sizes)
|
||||
y = m(x)
|
||||
x = reshape(Float32.(collect(1:prod(sizes))), sizes)
|
||||
y = trainmode(m, x)
|
||||
@test size(m.μ) == (sizes[end - 1], )
|
||||
@test size(m.σ²) == (sizes[end - 1], )
|
||||
@test size(y) == sizes
|
||||
end
|
||||
|
||||
# show that instance norm is equal to batch norm when channel and batch dims are squashed
|
||||
let m_inorm = InstanceNorm(2), m_bnorm = BatchNorm(12), sizes = (5, 5, 3, 4, 2, 6),
|
||||
x = reshape(collect(1:prod(sizes)), sizes)
|
||||
let m_inorm = trainmode(InstanceNorm(2)), m_bnorm = trainmode(BatchNorm(12)), sizes = (5, 5, 3, 4, 2, 6),
|
||||
x = reshape(Float32.(collect(1:prod(sizes))), sizes)
|
||||
@test m_inorm(x) == reshape(m_bnorm(reshape(x, (sizes[1:end - 2]..., :, 1))), sizes)
|
||||
end
|
||||
|
||||
|
@ -251,15 +250,14 @@ end
|
|||
|
||||
og_shape = size(x)
|
||||
|
||||
y = trainmode(m, x)
|
||||
y = m(x)
|
||||
x_ = reshape(x,affine_shape...)
|
||||
out = reshape(sigmoid.((x_ .- reshape(m.μ,μ_affine_shape...)) ./ sqrt.(reshape(m.σ²,μ_affine_shape...) .+ m.ϵ)),og_shape)
|
||||
@test isapprox(y, out, atol = 1.0e-7)
|
||||
end
|
||||
|
||||
let m = GroupNorm(2,2), sizes = (2, 4, 1, 2, 3),
|
||||
x = reshape(collect(1:prod(sizes)), sizes)
|
||||
let m = trainmode(GroupNorm(2,2)), sizes = (2, 4, 1, 2, 3),
|
||||
x = Float32.(reshape(collect(1:prod(sizes)), sizes))
|
||||
y = reshape(permutedims(x, [3, 1, 2, 4, 5]), :, 2, 3)
|
||||
y = reshape(m(y), sizes...)
|
||||
@test m(x) == y
|
||||
|
@ -267,22 +265,22 @@ end
|
|||
|
||||
# check that μ, σ², and the output are the correct size for higher rank tensors
|
||||
let m = GroupNorm(4,2), sizes = (5, 5, 3, 4, 4, 6),
|
||||
x = reshape(collect(1:prod(sizes)), sizes)
|
||||
y = m(x)
|
||||
x = Float32.(reshape(collect(1:prod(sizes)), sizes))
|
||||
y = trainmode(m, x)
|
||||
@test size(m.μ) == (m.G,1)
|
||||
@test size(m.σ²) == (m.G,1)
|
||||
@test size(y) == sizes
|
||||
end
|
||||
|
||||
# show that group norm is the same as instance norm when the group size is the same as the number of channels
|
||||
let IN = InstanceNorm(4), GN = GroupNorm(4,4), sizes = (2,2,3,4,5),
|
||||
x = reshape(collect(1:prod(sizes)), sizes)
|
||||
let IN = trainmode(InstanceNorm(4)), GN = trainmode(GroupNorm(4,4)), sizes = (2,2,3,4,5),
|
||||
x = Float32.(reshape(collect(1:prod(sizes)), sizes))
|
||||
@test IN(x) ≈ GN(x)
|
||||
end
|
||||
|
||||
# show that group norm is the same as batch norm for a group of size 1 and batch of size 1
|
||||
let BN = BatchNorm(4), GN = GroupNorm(4,4), sizes = (2,2,3,4,1),
|
||||
x = reshape(collect(1:prod(sizes)), sizes)
|
||||
let BN = trainmode(BatchNorm(4)), GN = trainmode(GroupNorm(4,4)), sizes = (2,2,3,4,1),
|
||||
x = Float32.(reshape(collect(1:prod(sizes)), sizes))
|
||||
@test BN(x) ≈ GN(x)
|
||||
end
|
||||
|
||||
|
|
Loading…
Reference in New Issue