From c41f8910052ab4ee85374c339bd8651fd84b4597 Mon Sep 17 00:00:00 2001
From: David Pollack
Date: Wed, 20 Feb 2019 14:51:55 +0100
Subject: [PATCH] changes based on the improved batchnorm in PR#633

---
 src/layers/normalise.jl | 17 +++++++++--------
 1 file changed, 9 insertions(+), 8 deletions(-)

diff --git a/src/layers/normalise.jl b/src/layers/normalise.jl
index eaa994b2..168f3363 100644
--- a/src/layers/normalise.jl
+++ b/src/layers/normalise.jl
@@ -185,6 +185,7 @@ m = Chain(
   softmax)
 ```
 """
+expand_inst = (x, as) -> reshape(repeat(x, outer=[1, as[length(as)]]), as...)
 mutable struct InstanceNorm{F,V,W,N}
   λ::F # activation function
   β::V # bias
@@ -207,7 +208,6 @@ function (IN::InstanceNorm)(x)
   ndims(x) > 2 ||
     error("InstanceNorm requires at least 3 dimensions. With 2 dimensions an array of zeros would be returned")
   # these are repeated later on depending on the batch size
-  γ, β = IN.γ, IN.β
   dims = length(size(x))
   c = size(x, dims-1)
   bs = size(x, dims)
@@ -215,10 +215,12 @@ function (IN::InstanceNorm)(x)
   affine_shape[end-1] = c
   affine_shape[end] = bs
   m = prod(size(x)[1:end-2])
+  γ, β = expand_inst(IN.γ, affine_shape), expand_inst(IN.β, affine_shape)
 
   if !IN.active
-    μ = reshape(repeat(IN.μ, outer=[bs]), affine_shape...)
-    σ² = reshape(repeat(IN.σ², outer=[bs]), affine_shape...)
+    μ = expand_inst(IN.μ, affine_shape)
+    σ² = expand_inst(IN.σ², affine_shape)
+    ϵ = IN.ϵ
   else
     T = eltype(x)
 
@@ -229,14 +231,13 @@ function (IN::InstanceNorm)(x)
 
     # update moving mean/std
     mtm = data(convert(T, IN.momentum))
-    IN.μ = reshape(mean((1 - mtm) .* repeat(IN.μ, outer=[1, bs]) .+ mtm .* reshape(data(μ), (c, bs)), dims = 2), :)
-    IN.σ² = reshape(mean(((1 - mtm) .* repeat(IN.σ², outer=[1, bs]) .+ mtm .* reshape(data(σ²), (c, bs)) .* (m / (m - 1))), dims = 2), :)
+    IN.μ = reshape(mean(repeat((1 - mtm) .* IN.μ, outer=[1, bs]) .+ mtm .* reshape(data(μ), (c, bs)), dims = 2), :)
+    IN.σ² = reshape(mean((repeat((1 - mtm) .* IN.σ², outer=[1, bs]) .+ reshape(data(σ²), (c, bs)) .* (mtm * m / (m - 1))), dims = 2), :)
   end
 
   let λ = IN.λ
-    temp = reshape(repeat(γ, outer=[bs]), affine_shape...) .* ((x .- μ) ./ sqrt.(σ² .+ IN.ϵ))
-    # This is intentionally not fused because of an extreme slowdown doing so
-    λ.(temp .+ reshape(repeat(β, outer=[bs]), affine_shape...))
+    x̂ = (x .- μ) ./ sqrt.(σ² .+ ϵ)
+    λ.(γ .* x̂ .+ β)
   end
 end
 