diff --git a/src/layers/normalise.jl b/src/layers/normalise.jl index 561b53df..5a8bdc56 100644 --- a/src/layers/normalise.jl +++ b/src/layers/normalise.jl @@ -229,10 +229,8 @@ function (in::InstanceNorm)(x) dims = length(size(x)) c = size(x, dims-1) bs = size(x, dims) - affine_shape = ones(Int, dims) - affine_shape[end-1] = c - affine_shape[end] = bs - m = prod(size(x)[1:end-2]) + affine_shape = ntuple(i->i == ndims(x) - 1 || i == ndims(x) ? size(x, i) : 1, ndims(x)) + m = div(prod(size(x)), c*bs) γ, β = expand_inst(in.γ, affine_shape), expand_inst(in.β, affine_shape) if !istraining() @@ -246,11 +244,11 @@ function (in::InstanceNorm)(x) axes = 1:dims-2 # axes to reduce along (all but channels and batch size axes) μ = mean(x, dims = axes) σ² = mean((x .- μ) .^ 2, dims = axes) - + S = eltype(in.μ) # update moving mean/std - mtm = convert(T, in.momentum) - in.μ = dropdims(mean(repeat((1 - mtm) .* in.μ, outer=[1, bs]) .+ mtm .* reshape(μ, (c, bs)), dims = 2), dims=2) - in.σ² = dropdims(mean((repeat((1 - mtm) .* in.σ², outer=[1, bs]) .+ (mtm * m / (m - 1)) .* reshape(σ², (c, bs))), dims = 2), dims=2) + mtm = in.momentum + in.μ = dropdims(mean(repeat((1 - mtm) .* in.μ, outer=[1, bs]) .+ mtm .* S.(reshape(μ, (c, bs))), dims = 2), dims=2) + in.σ² = dropdims(mean((repeat((1 - mtm) .* in.σ², outer=[1, bs]) .+ (mtm * m / (m - 1)) .* S.(reshape(σ², (c, bs)))), dims = 2), dims=2) end let λ = in.λ @@ -320,13 +318,10 @@ function(gn::GroupNorm)(x) channels = size(x, dims-1) batches = size(x,dims) channels_per_group = div(channels,groups) - affine_shape = ones(Int, dims) + affine_shape = ntuple(i->i == ndims(x) - 1 ? size(x, i) : 1, ndims(x)) # Output reshaped to (W,H...,C/G,G,N) - affine_shape[end-1] = channels - - μ_affine_shape = ones(Int,dims + 1) - μ_affine_shape[end-1] = groups + μ_affine_shape = ntuple(i->i == ndims(x) ? groups : 1, ndims(x) + 1) m = prod(size(x)[1:end-2]) * channels_per_group γ = reshape(gn.γ, affine_shape...) @@ -345,12 +340,12 @@ function(gn::GroupNorm)(x) μ = mean(y, dims = axes) σ² = mean((y .- μ) .^ 2, dims = axes) - ϵ = data(convert(T, gn.ϵ)) + ϵ = convert(T, gn.ϵ) # update moving mean/std - mtm = data(convert(T, gn.momentum)) - - gn.μ = mean((1 - mtm) .* gn.μ .+ mtm .* reshape(data(μ), (groups,batches)),dims=2) - gn.σ² = mean((1 - mtm) .* gn.σ² .+ (mtm * m / (m - 1)) .* reshape(data(σ²), (groups,batches)),dims=2) + mtm = gn.momentum + S = eltype(gn.μ) + gn.μ = mean((1 - mtm) .* gn.μ .+ mtm .* S.(reshape(μ, (groups,batches))),dims=2) + gn.σ² = mean((1 - mtm) .* gn.σ² .+ (mtm * m / (m - 1)) .* S.(reshape(σ², (groups,batches))),dims=2) end let λ = gn.λ