broken normalisation layer params

This commit is contained in:
Mike Innes 2019-09-19 14:12:11 +01:00
parent fc9db7ee74
commit f8d5d3b5fc

View File

@ -42,6 +42,8 @@ end
let m = BatchNorm(2), x = [1.0 3.0 5.0;
2.0 4.0 6.0]
@test_broken length(params(m)) == 2
@test m.β == [0, 0] # initβ(2)
@test m.γ == [1, 1] # initγ(2)
# initial m.σ is 1
@ -110,6 +112,8 @@ end
# begin tests
let m = InstanceNorm(2), sizes = (3, 2, 2),
x = reshape(collect(1:prod(sizes)), sizes)
@test_broken length(params(m)) == 2
x = Float64.(x)
@test m.β == [0, 0] # initβ(2)
@test m.γ == [1, 1] # initγ(2)
@ -193,6 +197,8 @@ end
let m = GroupNorm(4,2), sizes = (3,4,2),
x = reshape(collect(1:prod(sizes)), sizes)
@test_broken length(params(m)) == 2
x = Float64.(x)
@test m.β == [0, 0, 0, 0] # initβ(32)
@test m.γ == [1, 1, 1, 1] # initγ(32)