fix normalisation layer params
This commit is contained in:
parent
6529dbcbe6
commit
b951377426
|
@ -37,14 +37,14 @@ function fmap(f, x; cache = IdDict())
|
|||
cache[x] = isleaf(x) ? f(x) : fmap1(x -> fmap(f, x, cache = cache), x)
|
||||
end
|
||||
|
||||
children(m) = functor(m)[1]
|
||||
trainable(m) = functor(m)[1]
|
||||
|
||||
params!(p::Params, x::AbstractArray{<:Real}, seen = IdSet()) = push!(p, x)
|
||||
|
||||
function params!(p::Params, x, seen = IdSet())
|
||||
x in seen && return
|
||||
push!(seen, x)
|
||||
for child in children(x)
|
||||
for child in trainable(x)
|
||||
params!(p, child, seen)
|
||||
end
|
||||
end
|
||||
|
|
|
@ -134,6 +134,8 @@ BatchNorm(chs::Integer, λ = identity;
|
|||
BatchNorm(λ, initβ(chs), initγ(chs),
|
||||
zeros(chs), ones(chs), ϵ, momentum)
|
||||
|
||||
trainable(bn::BatchNorm) = (bn.β, bn.γ)
|
||||
|
||||
function (BN::BatchNorm)(x)
|
||||
size(x, ndims(x)-1) == length(BN.β) ||
|
||||
error("BatchNorm expected $(length(BN.β)) channels, got $(size(x, ndims(x)-1))")
|
||||
|
@ -220,6 +222,8 @@ InstanceNorm(chs::Integer, λ = identity;
|
|||
InstanceNorm(λ, initβ(chs), initγ(chs),
|
||||
zeros(chs), ones(chs), ϵ, momentum)
|
||||
|
||||
trainable(in::InstanceNorm) = (in.β, in.γ)
|
||||
|
||||
function (in::InstanceNorm)(x)
|
||||
size(x, ndims(x)-1) == length(in.β) ||
|
||||
error("InstanceNorm expected $(length(in.β)) channels, got $(size(x, ndims(x)-1))")
|
||||
|
@ -303,6 +307,8 @@ GroupNorm(chs::Integer, G::Integer, λ = identity;
|
|||
GroupNorm(G, λ, initβ(chs), initγ(chs),
|
||||
zeros(G,1), ones(G,1), ϵ, momentum)
|
||||
|
||||
trainable(gn::GroupNorm) = (gn.β, gn.γ)
|
||||
|
||||
function(gn::GroupNorm)(x)
|
||||
size(x,ndims(x)-1) == length(gn.β) || error("Group Norm expected $(length(gn.β)) channels, but got $(size(x,ndims(x)-1)) channels")
|
||||
ndims(x) > 2 || error("Need to pass at least 3 channels for Group Norm to work")
|
||||
|
|
|
@ -42,7 +42,7 @@ end
|
|||
let m = BatchNorm(2), x = [1.0 3.0 5.0;
|
||||
2.0 4.0 6.0]
|
||||
|
||||
@test_broken length(params(m)) == 2
|
||||
@test length(params(m)) == 2
|
||||
|
||||
@test m.β == [0, 0] # initβ(2)
|
||||
@test m.γ == [1, 1] # initγ(2)
|
||||
|
@ -113,7 +113,7 @@ end
|
|||
let m = InstanceNorm(2), sizes = (3, 2, 2),
|
||||
x = reshape(collect(1:prod(sizes)), sizes)
|
||||
|
||||
@test_broken length(params(m)) == 2
|
||||
@test length(params(m)) == 2
|
||||
x = Float64.(x)
|
||||
@test m.β == [0, 0] # initβ(2)
|
||||
@test m.γ == [1, 1] # initγ(2)
|
||||
|
@ -198,7 +198,7 @@ end
|
|||
let m = GroupNorm(4,2), sizes = (3,4,2),
|
||||
x = reshape(collect(1:prod(sizes)), sizes)
|
||||
|
||||
@test_broken length(params(m)) == 2
|
||||
@test length(params(m)) == 2
|
||||
x = Float64.(x)
|
||||
@test m.β == [0, 0, 0, 0] # initβ(32)
|
||||
@test m.γ == [1, 1, 1, 1] # initγ(32)
|
||||
|
|
|
@ -83,7 +83,6 @@ end
|
|||
|
||||
# Self-referential array. Just want params, no stack overflow pls.
|
||||
r = Any[nothing,m]
|
||||
Flux.children(a::Vector{Any}) = Tuple(a)
|
||||
r[1] = r
|
||||
@test size.(params(r)) == [(5, 10), (5, 5), (5,), (5,)]
|
||||
end
|
||||
|
|
Loading…
Reference in New Issue