broken normalisation layer params

This commit is contained in:
Mike Innes 2019-09-19 14:12:11 +01:00
parent fc9db7ee74
commit f8d5d3b5fc
1 changed files with 8 additions and 2 deletions

View File

@ -42,6 +42,8 @@ end
let m = BatchNorm(2), x = [1.0 3.0 5.0;
2.0 4.0 6.0]
@test_broken length(params(m)) == 2
@test m.β == [0, 0] # initβ(2)
@test m.γ == [1, 1] # initγ(2)
# initial m.σ is 1
@ -109,7 +111,9 @@ end
expand_inst = (x, as) -> reshape(repeat(x, outer=[1, as[length(as)]]), as...)
# begin tests
let m = InstanceNorm(2), sizes = (3, 2, 2),
x = reshape(collect(1:prod(sizes)), sizes)
x = reshape(collect(1:prod(sizes)), sizes)
@test_broken length(params(m)) == 2
x = Float64.(x)
@test m.β == [0, 0] # initβ(2)
@test m.γ == [1, 1] # initγ(2)
@ -192,7 +196,9 @@ end
squeeze(x) = dropdims(x, dims = tuple(findall(size(x) .== 1)...)) # To remove all singular dimensions
let m = GroupNorm(4,2), sizes = (3,4,2),
x = reshape(collect(1:prod(sizes)), sizes)
x = reshape(collect(1:prod(sizes)), sizes)
@test_broken length(params(m)) == 2
x = Float64.(x)
@test m.β == [0, 0, 0, 0] # initβ(32)
@test m.γ == [1, 1, 1, 1] # initγ(32)