normalise test fixes

Mike Innes 2019-09-10 16:19:55 +01:00
parent 877415be10
commit 250aef5a5a
1 changed file with 19 additions and 21 deletions


@@ -1,7 +1,8 @@
-using Flux, Test
+using Flux, Test, Statistics
 using Zygote: forward

 trainmode(f, x...) = forward(f, x...)[1]
+trainmode(f) = (x...) -> trainmode(f, x...)

 @testset "Dropout" begin
   x = [1.,2.,3.]
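A note on the helpers this hunk touches: at this point in Flux's history there is no explicit train/test flag on the layers; Flux's internal `istraining()` only returns `true` when code runs under Zygote differentiation. `trainmode(f, x...) = forward(f, x...)[1]` exploits that: `Zygote.forward` (later renamed `pullback`) returns `(output, back)`, and keeping only `[1]` gives the layer's output as computed in training mode. The new curried method `trainmode(f)` wraps a layer once so that every later call runs in train mode. A minimal sketch of the difference, assuming the Flux/Zygote API of this era:

using Flux, Statistics
using Zygote: forward

bn = BatchNorm(2)
x = randn(Float32, 2, 5)

y_test  = bn(x)               # test mode: uses running μ and σ² (initially 0 and 1)
y_train = forward(bn, x)[1]   # train mode: normalised with the current batch statistics

# in train mode each channel (row) is centred on the batch:
@assert all(abs.(mean(y_train; dims = 2)) .< 1f-5)
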
@@ -75,24 +76,23 @@ end
   # with activation function
   let m = BatchNorm(2, sigmoid), x = [1.0 3.0 5.0;
                                       2.0 4.0 6.0]
-    y = trainmode(m, x)
+    y = m(x)
     @test isapprox(y, sigmoid.((x .- m.μ) ./ sqrt.(m.σ² .+ m.ϵ)), atol = 1.0e-7)
   end

-  let m = BatchNorm(2), x = reshape(1:6, 3, 2, 1)
+  let m = trainmode(BatchNorm(2)), x = reshape(Float32.(1:6), 3, 2, 1)
     y = reshape(permutedims(x, [2, 1, 3]), 2, :)
     y = permutedims(reshape(m(y), 2, 3, 1), [2, 1, 3])
     @test m(x) == y
   end

-  let m = BatchNorm(2), x = reshape(1:12, 2, 3, 2, 1)
+  let m = trainmode(BatchNorm(2)), x = reshape(Float32.(1:12), 2, 3, 2, 1)
     y = reshape(permutedims(x, [3, 1, 2, 4]), 2, :)
     y = permutedims(reshape(m(y), 2, 2, 3, 1), [2, 3, 1, 4])
     @test m(x) == y
   end

-  let m = BatchNorm(2), x = reshape(1:24, 2, 2, 3, 2, 1)
+  let m = trainmode(BatchNorm(2)), x = reshape(Float32.(1:24), 2, 2, 3, 2, 1)
     y = reshape(permutedims(x, [4, 1, 2, 3, 5]), 2, :)
     y = permutedims(reshape(m(y), 2, 2, 2, 3, 1), [2, 3, 4, 1, 5])
     @test m(x) == y
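These three `let` blocks all encode the same invariant: `BatchNorm` takes its statistics over every dimension except the channel dimension (the second-to-last), so permuting channels to the front and flattening to a channels × rest matrix must give the same answer. A self-contained sketch of the rank-4 case in plain Julia (the explicit `mean`/`var` normalisation below is illustrative, not Flux's exact implementation):

using Statistics

x = reshape(Float32.(1:12), 2, 3, 2, 1)   # (W, H, C, N) with C = 2

# direct rank-4 normalisation: reduce over all dims except the channel dim
μ  = mean(x; dims = (1, 2, 4))
σ² = var(x; dims = (1, 2, 4), corrected = false)
y  = (x .- μ) ./ sqrt.(σ² .+ 1f-5)

# the flattened route the tests take: channels first, then channels × rest
xmat = reshape(permutedims(x, (3, 1, 2, 4)), 2, :)
ymat = (xmat .- mean(xmat; dims = 2)) ./
       sqrt.(var(xmat; dims = 2, corrected = false) .+ 1f-5)
y2 = permutedims(reshape(ymat, 2, 2, 3, 1), (2, 3, 1, 4))

@assert y ≈ y2
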
@@ -154,13 +154,12 @@ end
     affine_shape = collect(sizes)
     affine_shape[1] = 1
-    y = trainmode(m, x)
+    y = m(x)
     @test isapprox(y, sigmoid.((x .- expand_inst(m.μ, affine_shape)) ./ sqrt.(expand_inst(m.σ², affine_shape) .+ m.ϵ)), atol = 1.0e-7)
   end

-  let m = InstanceNorm(2), sizes = (2, 4, 1, 2, 3),
-      x = reshape(collect(1:prod(sizes)), sizes)
+  let m = trainmode(InstanceNorm(2)), sizes = (2, 4, 1, 2, 3),
+      x = Float32.(reshape(collect(1:prod(sizes)), sizes))
     y = reshape(permutedims(x, [3, 1, 2, 4, 5]), :, 2, 3)
     y = reshape(m(y), sizes...)
     @test m(x) == y
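For `InstanceNorm` the analogous invariant is per-(channel, sample): statistics are taken over all dims except the last two, so flattening the leading dims as the test above does leaves each (channel, sample) slice intact. A short plain-Julia illustration of the reduction (again an explicit sketch, not Flux's code):

using Statistics

sizes = (2, 4, 1, 2, 3)                       # (..., C, N) with C = 2, N = 3
x = Float32.(reshape(collect(1:prod(sizes)), sizes))

μ  = mean(x; dims = (1, 2, 3))                # one statistic per (channel, sample)
σ² = var(x; dims = (1, 2, 3), corrected = false)
y  = (x .- μ) ./ sqrt.(σ² .+ 1f-5)

@assert size(μ) == (1, 1, 1, 2, 3)            # 2 channels × 3 samples
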
@@ -168,16 +167,16 @@ end
   # check that μ, σ², and the output are the correct size for higher rank tensors
   let m = InstanceNorm(2), sizes = (5, 5, 3, 4, 2, 6),
-      x = reshape(collect(1:prod(sizes)), sizes)
-    y = m(x)
+      x = reshape(Float32.(collect(1:prod(sizes))), sizes)
+    y = trainmode(m, x)
     @test size(m.μ) == (sizes[end - 1], )
     @test size(m.σ²) == (sizes[end - 1], )
     @test size(y) == sizes
   end

   # show that instance norm is equal to batch norm when channel and batch dims are squashed
-  let m_inorm = InstanceNorm(2), m_bnorm = BatchNorm(12), sizes = (5, 5, 3, 4, 2, 6),
-      x = reshape(collect(1:prod(sizes)), sizes)
+  let m_inorm = trainmode(InstanceNorm(2)), m_bnorm = trainmode(BatchNorm(12)), sizes = (5, 5, 3, 4, 2, 6),
+      x = reshape(Float32.(collect(1:prod(sizes))), sizes)
     @test m_inorm(x) == reshape(m_bnorm(reshape(x, (sizes[1:end - 2]..., :, 1))), sizes)
   end
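The squashing identity tested here follows directly from those reductions: reshaping (dims..., C, N) to (dims..., C·N, 1) turns every (channel, sample) pair into its own channel with batch size one, so batch-norm statistics collapse to instance-norm statistics. Sketched with explicit statistics (illustrative only):

using Statistics

sizes = (5, 5, 3, 4, 2, 6)
x = reshape(Float32.(collect(1:prod(sizes))), sizes)

inorm = (x .- mean(x; dims = (1, 2, 3, 4))) ./
        sqrt.(var(x; dims = (1, 2, 3, 4), corrected = false) .+ 1f-5)

xs = reshape(x, (sizes[1:end-2]..., :, 1))    # squash C and N into the channel dim
bnorm = (xs .- mean(xs; dims = (1, 2, 3, 4))) ./
        sqrt.(var(xs; dims = (1, 2, 3, 4), corrected = false) .+ 1f-5)

@assert inorm ≈ reshape(bnorm, sizes)
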
@@ -251,15 +250,14 @@ end
     og_shape = size(x)
-    y = trainmode(m, x)
+    y = m(x)
     x_ = reshape(x,affine_shape...)
     out = reshape(sigmoid.((x_ .- reshape(m.μ,μ_affine_shape...)) ./ sqrt.(reshape(m.σ²,μ_affine_shape...) .+ m.ϵ)),og_shape)
     @test isapprox(y, out, atol = 1.0e-7)
   end

-  let m = GroupNorm(2,2), sizes = (2, 4, 1, 2, 3),
-      x = reshape(collect(1:prod(sizes)), sizes)
+  let m = trainmode(GroupNorm(2,2)), sizes = (2, 4, 1, 2, 3),
+      x = Float32.(reshape(collect(1:prod(sizes)), sizes))
     y = reshape(permutedims(x, [3, 1, 2, 4, 5]), :, 2, 3)
     y = reshape(m(y), sizes...)
     @test m(x) == y
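`GroupNorm(ch, G)` sits between the two: the `ch` channels are split into `G` groups and statistics are taken per (group, sample). The reduction, sketched in plain Julia for a case with more than one channel per group (shapes here are hypothetical, chosen only for illustration):

using Statistics

sizes = (3, 3, 4, 2)                          # (W, H, C, N) with C = 4
x = Float32.(reshape(1:prod(sizes), sizes))
G = 2

xg = reshape(x, 3, 3, 4 ÷ G, G, 2)            # split C into (C ÷ G, G)
μ  = mean(xg; dims = (1, 2, 3))               # one statistic per (group, sample)
σ² = var(xg; dims = (1, 2, 3), corrected = false)
y  = reshape((xg .- μ) ./ sqrt.(σ² .+ 1f-5), sizes)
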
@@ -267,22 +265,22 @@ end
   # check that μ, σ², and the output are the correct size for higher rank tensors
   let m = GroupNorm(4,2), sizes = (5, 5, 3, 4, 4, 6),
-      x = reshape(collect(1:prod(sizes)), sizes)
-    y = m(x)
+      x = Float32.(reshape(collect(1:prod(sizes)), sizes))
+    y = trainmode(m, x)
     @test size(m.μ) == (m.G,1)
     @test size(m.σ²) == (m.G,1)
     @test size(y) == sizes
   end

   # show that group norm is the same as instance norm when the group size is the same as the number of channels
-  let IN = InstanceNorm(4), GN = GroupNorm(4,4), sizes = (2,2,3,4,5),
-      x = reshape(collect(1:prod(sizes)), sizes)
+  let IN = trainmode(InstanceNorm(4)), GN = trainmode(GroupNorm(4,4)), sizes = (2,2,3,4,5),
+      x = Float32.(reshape(collect(1:prod(sizes)), sizes))
     @test IN(x) ≈ GN(x)
   end

   # show that group norm is the same as batch norm for a group of size 1 and batch of size 1
-  let BN = BatchNorm(4), GN = GroupNorm(4,4), sizes = (2,2,3,4,1),
-      x = reshape(collect(1:prod(sizes)), sizes)
+  let BN = trainmode(BatchNorm(4)), GN = trainmode(GroupNorm(4,4)), sizes = (2,2,3,4,1),
+      x = Float32.(reshape(collect(1:prod(sizes)), sizes))
     @test BN(x) ≈ GN(x)
   end
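Both equivalences in this hunk fall out of the same picture. With `G == ch` each group holds exactly one channel, so the group reduction is the instance reduction; with a batch of one, batch norm's per-channel reduction also collapses to the same thing. One self-contained check of both identities (explicit statistics, not Flux's implementation):

using Statistics

stats_norm(x, dims; ϵ = 1f-5) =
    (x .- mean(x; dims = dims)) ./ sqrt.(var(x; dims = dims, corrected = false) .+ ϵ)

sizes = (2, 2, 3, 4, 1)                        # 4 channels, batch of one
x = Float32.(reshape(1:prod(sizes), sizes))

inorm = stats_norm(x, (1, 2, 3))               # per (channel, sample)
gnorm = reshape(stats_norm(reshape(x, 2, 2, 3, 1, 4, 1), (1, 2, 3, 4)), sizes)
bnorm = stats_norm(x, (1, 2, 3, 5))            # per channel, over the batch too

@assert inorm ≈ gnorm                          # G == C
@assert bnorm ≈ gnorm                          # N == 1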