Updated InstanceNorm and GroupNorm to avoid mutation
This commit is contained in:
parent
a645a86927
commit
faac0ff08b
@ -229,10 +229,8 @@ function (in::InstanceNorm)(x)
|
||||
dims = length(size(x))
|
||||
c = size(x, dims-1)
|
||||
bs = size(x, dims)
|
||||
affine_shape = ones(Int, dims)
|
||||
affine_shape[end-1] = c
|
||||
affine_shape[end] = bs
|
||||
m = prod(size(x)[1:end-2])
|
||||
affine_shape = ntuple(i->i == ndims(x) - 1 || i == ndims(x) ? size(x, i) : 1, ndims(x))
|
||||
m = div(prod(size(x)), c*bs)
|
||||
γ, β = expand_inst(in.γ, affine_shape), expand_inst(in.β, affine_shape)
|
||||
|
||||
if !istraining()
|
||||
@ -246,11 +244,11 @@ function (in::InstanceNorm)(x)
|
||||
axes = 1:dims-2 # axes to reduce along (all but channels and batch size axes)
|
||||
μ = mean(x, dims = axes)
|
||||
σ² = mean((x .- μ) .^ 2, dims = axes)
|
||||
|
||||
S = eltype(in.μ)
|
||||
# update moving mean/std
|
||||
mtm = convert(T, in.momentum)
|
||||
in.μ = dropdims(mean(repeat((1 - mtm) .* in.μ, outer=[1, bs]) .+ mtm .* reshape(μ, (c, bs)), dims = 2), dims=2)
|
||||
in.σ² = dropdims(mean((repeat((1 - mtm) .* in.σ², outer=[1, bs]) .+ (mtm * m / (m - 1)) .* reshape(σ², (c, bs))), dims = 2), dims=2)
|
||||
mtm = in.momentum
|
||||
in.μ = dropdims(mean(repeat((1 - mtm) .* in.μ, outer=[1, bs]) .+ mtm .* S.(reshape(μ, (c, bs))), dims = 2), dims=2)
|
||||
in.σ² = dropdims(mean((repeat((1 - mtm) .* in.σ², outer=[1, bs]) .+ (mtm * m / (m - 1)) .* S.(reshape(σ², (c, bs)))), dims = 2), dims=2)
|
||||
end
|
||||
|
||||
let λ = in.λ
|
||||
@ -320,13 +318,10 @@ function(gn::GroupNorm)(x)
|
||||
channels = size(x, dims-1)
|
||||
batches = size(x,dims)
|
||||
channels_per_group = div(channels,groups)
|
||||
affine_shape = ones(Int, dims)
|
||||
affine_shape = ntuple(i->i == ndims(x) - 1 ? size(x, i) : 1, ndims(x))
|
||||
|
||||
# Output reshaped to (W,H...,C/G,G,N)
|
||||
affine_shape[end-1] = channels
|
||||
|
||||
μ_affine_shape = ones(Int,dims + 1)
|
||||
μ_affine_shape[end-1] = groups
|
||||
μ_affine_shape = ntuple(i->i == ndims(x) ? groups : 1, ndims(x) + 1)
|
||||
|
||||
m = prod(size(x)[1:end-2]) * channels_per_group
|
||||
γ = reshape(gn.γ, affine_shape...)
|
||||
@ -345,12 +340,12 @@ function(gn::GroupNorm)(x)
|
||||
μ = mean(y, dims = axes)
|
||||
σ² = mean((y .- μ) .^ 2, dims = axes)
|
||||
|
||||
ϵ = data(convert(T, gn.ϵ))
|
||||
ϵ = convert(T, gn.ϵ)
|
||||
# update moving mean/std
|
||||
mtm = data(convert(T, gn.momentum))
|
||||
|
||||
gn.μ = mean((1 - mtm) .* gn.μ .+ mtm .* reshape(data(μ), (groups,batches)),dims=2)
|
||||
gn.σ² = mean((1 - mtm) .* gn.σ² .+ (mtm * m / (m - 1)) .* reshape(data(σ²), (groups,batches)),dims=2)
|
||||
mtm = gn.momentum
|
||||
S = eltype(gn.μ)
|
||||
gn.μ = mean((1 - mtm) .* gn.μ .+ mtm .* S.(reshape(μ, (groups,batches))),dims=2)
|
||||
gn.σ² = mean((1 - mtm) .* gn.σ² .+ (mtm * m / (m - 1)) .* S.(reshape(σ², (groups,batches))),dims=2)
|
||||
end
|
||||
|
||||
let λ = gn.λ
|
||||
|
Loading…
Reference in New Issue
Block a user