fix normalisation layer params

This commit is contained in:
Mike Innes 2019-09-19 15:33:24 +01:00
parent 6529dbcbe6
commit b951377426
4 changed files with 11 additions and 6 deletions

View File

@ -37,14 +37,14 @@ function fmap(f, x; cache = IdDict())
cache[x] = isleaf(x) ? f(x) : fmap1(x -> fmap(f, x, cache = cache), x)
end
children(m) = functor(m)[1]
trainable(m) = functor(m)[1]
params!(p::Params, x::AbstractArray{<:Real}, seen = IdSet()) = push!(p, x)
function params!(p::Params, x, seen = IdSet())
x in seen && return
push!(seen, x)
for child in children(x)
for child in trainable(x)
params!(p, child, seen)
end
end

View File

@ -134,6 +134,8 @@ BatchNorm(chs::Integer, λ = identity;
BatchNorm(λ, initβ(chs), initγ(chs),
zeros(chs), ones(chs), ϵ, momentum)
trainable(bn::BatchNorm) = (bn.β, bn.γ)
function (BN::BatchNorm)(x)
size(x, ndims(x)-1) == length(BN.β) ||
error("BatchNorm expected $(length(BN.β)) channels, got $(size(x, ndims(x)-1))")
@ -220,6 +222,8 @@ InstanceNorm(chs::Integer, λ = identity;
InstanceNorm(λ, initβ(chs), initγ(chs),
zeros(chs), ones(chs), ϵ, momentum)
trainable(in::InstanceNorm) = (in.β, in.γ)
function (in::InstanceNorm)(x)
size(x, ndims(x)-1) == length(in.β) ||
error("InstanceNorm expected $(length(in.β)) channels, got $(size(x, ndims(x)-1))")
@ -303,6 +307,8 @@ GroupNorm(chs::Integer, G::Integer, λ = identity;
GroupNorm(G, λ, initβ(chs), initγ(chs),
zeros(G,1), ones(G,1), ϵ, momentum)
trainable(gn::GroupNorm) = (gn.β, gn.γ)
function(gn::GroupNorm)(x)
size(x,ndims(x)-1) == length(gn.β) || error("Group Norm expected $(length(gn.β)) channels, but got $(size(x,ndims(x)-1)) channels")
ndims(x) > 2 || error("Need to pass at least 3 channels for Group Norm to work")

View File

@ -42,7 +42,7 @@ end
let m = BatchNorm(2), x = [1.0 3.0 5.0;
2.0 4.0 6.0]
@test_broken length(params(m)) == 2
@test length(params(m)) == 2
@test m.β == [0, 0] # initβ(2)
@test m.γ == [1, 1] # initγ(2)
@ -113,7 +113,7 @@ end
let m = InstanceNorm(2), sizes = (3, 2, 2),
x = reshape(collect(1:prod(sizes)), sizes)
@test_broken length(params(m)) == 2
@test length(params(m)) == 2
x = Float64.(x)
@test m.β == [0, 0] # initβ(2)
@test m.γ == [1, 1] # initγ(2)
@ -198,7 +198,7 @@ end
let m = GroupNorm(4,2), sizes = (3,4,2),
x = reshape(collect(1:prod(sizes)), sizes)
@test_broken length(params(m)) == 2
@test length(params(m)) == 2
x = Float64.(x)
@test m.β == [0, 0, 0, 0] # initβ(32)
@test m.γ == [1, 1, 1, 1] # initγ(32)

View File

@ -83,7 +83,6 @@ end
# Self-referential array. Just want params, no stack overflow pls.
r = Any[nothing,m]
Flux.children(a::Vector{Any}) = Tuple(a)
r[1] = r
@test size.(params(r)) == [(5, 10), (5, 5), (5,), (5,)]
end