instance normalization

David Pollack 2019-02-20 14:01:05 +01:00
parent 3a4c6274fa
commit 129a708b6f
3 changed files with 180 additions and 1 deletion

@@ -7,7 +7,7 @@ using MacroTools, Juno, Requires, Reexport, Statistics, Random
 using MacroTools: @forward
 export Chain, Dense, RNN, LSTM, GRU, Conv, ConvTranspose, MaxPool, MeanPool,
-  DepthwiseConv, Dropout, LayerNorm, BatchNorm,
+  DepthwiseConv, Dropout, LayerNorm, BatchNorm, InstanceNorm,
   params, mapleaves, cpu, gpu, f32, f64
 @reexport using NNlib
@@ -155,3 +155,101 @@ function Base.show(io::IO, l::BatchNorm)
  (l.λ == identity) || print(io, ", λ = $(l.λ)")
  print(io, ")")
end
"""
InstanceNorm(channels::Integer, σ = identity;
initβ = zeros, initγ = ones,
ϵ = 1e-8, momentum = .1)
Instance Normalization layer. The `channels` input should be the size of the
channel dimension in your data (see below).
Given an array with `N` dimensions, call the `N-1`th the channel dimension. (For
a batch of feature vectors this is just the data dimension, for `WHCN` images
it's the usual channel dimension.)
`InstanceNorm` computes the mean and variance for each each `W×H×1×1` slice and
shifts them to have a new mean and variance (corresponding to the learnable,
per-channel `bias` and `scale` parameters).
See [Instance Normalization: The Missing Ingredient for Fast Stylization](https://arxiv.org/abs/1607.08022).
Example:
```julia
m = Chain(
Dense(28^2, 64),
InstanceNorm(64, relu),
Dense(64, 10),
InstanceNorm(10),
softmax)
```
"""
mutable struct InstanceNorm{F,V,W,N}
  λ::F  # activation function
  β::V  # bias
  γ::V  # scale
  μ::W  # moving mean
  σ²::W # moving variance
  ϵ::N
  momentum::N
  active::Bool
end
InstanceNorm(chs::Integer, λ = identity;
             initβ = (i) -> zeros(Float32, i), initγ = (i) -> ones(Float32, i),
             ϵ = 1f-5, momentum = 0.1f0) =
  InstanceNorm(λ, param(initβ(chs)), param(initγ(chs)),
               zeros(chs), ones(chs), ϵ, momentum, true)
function (IN::InstanceNorm)(x)
  size(x, ndims(x)-1) == length(IN.β) ||
    error("InstanceNorm expected $(length(IN.β)) channels, got $(size(x, ndims(x)-1))")
  ndims(x) > 2 ||
    error("InstanceNorm requires at least 3 dimensions. With 2 dimensions an array of zeros would be returned")
  # these are repeated later on depending on the batch size
  γ, β = IN.γ, IN.β

  dims = length(size(x))
  c = size(x, dims-1)
  bs = size(x, dims)
  affine_shape = ones(Int, dims)
  affine_shape[end-1] = c
  affine_shape[end] = bs
  m = prod(size(x)[1:end-2])

  if !IN.active
    μ = reshape(repeat(IN.μ, outer=[bs]), affine_shape...)
    σ² = reshape(repeat(IN.σ², outer=[bs]), affine_shape...)
  else
    T = eltype(x)

    ϵ = data(convert(T, IN.ϵ))
    axes = 1:dims-2 # axes to reduce along (all but the channel and batch axes)
    μ = mean(x, dims = axes)
    σ² = mean((x .- μ) .^ 2, dims = axes)

    # update moving mean/variance
    mtm = data(convert(T, IN.momentum))
    IN.μ = reshape(mean((1 - mtm) .* repeat(IN.μ, outer=[1, bs]) .+ mtm .* reshape(data(μ), (c, bs)), dims = 2), :)
    IN.σ² = reshape(mean(((1 - mtm) .* repeat(IN.σ², outer=[1, bs]) .+ mtm .* reshape(data(σ²), (c, bs)) .* (m / (m - 1))), dims = 2), :)
  end

  let λ = IN.λ
    temp = reshape(repeat(γ, outer=[bs]), affine_shape...) .* ((x .- μ) ./ sqrt.(σ² .+ IN.ϵ))
    # This is intentionally not fused because of an extreme slowdown doing so
    λ.(temp .+ reshape(repeat(β, outer=[bs]), affine_shape...))
  end
end
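
A quick standalone check of the reduction step above (a sketch in plain `Statistics`, not part of this commit): reducing over axes `1:dims-2` yields one mean and one variance per channel-and-sample slice, i.e. the `W×H×1×1` statistics the docstring describes.

```julia
using Statistics

# 3 "spatial" points × 2 channels × batch of 2: the smallest (W..., C, N) case.
# Not part of the commit; an illustration of the layer's reduction axes.
x = reshape(collect(1.0:12.0), 3, 2, 2)

# Mirror the layer's `axes = 1:dims-2` reduction (here just axis 1).
μ  = mean(x, dims = 1)              # one mean per (channel, sample) slice
σ² = mean((x .- μ) .^ 2, dims = 1)  # biased variance, as in the forward pass
@assert vec(μ) == [2.0, 5.0, 8.0, 11.0]  # matches the hand computation in the tests below

# After normalization every slice has (approximately) zero mean and unit variance.
x̂ = (x .- μ) ./ sqrt.(σ² .+ 1f-5)
@assert all(abs.(mean(x̂, dims = 1)) .< 1e-6)
@assert all(abs.(mean(x̂ .^ 2, dims = 1) .- 1) .< 1e-3)
```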
children(IN::InstanceNorm) =
  (IN.λ, IN.β, IN.γ, IN.μ, IN.σ², IN.ϵ, IN.momentum, IN.active)

mapchildren(f, IN::InstanceNorm) = # e.g. mapchildren(cu, IN)
  InstanceNorm(IN.λ, f(IN.β), f(IN.γ), f(IN.μ), f(IN.σ²), IN.ϵ, IN.momentum, IN.active)

_testmode!(IN::InstanceNorm, test) = (IN.active = !test)

function Base.show(io::IO, l::InstanceNorm)
  print(io, "InstanceNorm($(join(size(l.β), ", "))")
  (l.λ == identity) || print(io, ", λ = $(l.λ)")
  print(io, ")")
end
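
For completeness, a minimal usage sketch (assuming the Flux API of this era, with `param`-tracked layers and the `testmode!` toggle used in the tests below):

```julia
using Flux

# Illustration only, not part of the diff.
m = InstanceNorm(2)            # two channels
x = rand(Float32, 5, 5, 2, 4)  # WHCN: four 5×5 two-channel samples

y = m(x)           # training mode: normalizes each 5×5×1×1 slice and
                   # updates the running statistics m.μ and m.σ²

Flux.testmode!(m)  # switch `active` off
ŷ = m(x)           # inference mode: uses the stored running statistics
```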

@@ -104,3 +104,84 @@ end
    @test (@allocated m(x)) < 100_000_000
  end
end
@testset "InstanceNorm" begin
# helper functions
expand_inst = (x, as) -> reshape(repeat(x, outer=[1, as[length(as)]]), as...)
# begin tests
let m = InstanceNorm(2), sizes = (3, 2, 2),
x = param(reshape(collect(1:prod(sizes)), sizes))
@test m.β.data == [0, 0] # initβ(2)
@test m.γ.data == [1, 1] # initγ(2)
@test m.active
m(x)
#julia> x
#[:, :, 1] =
# 1.0 4.0
# 2.0 5.0
# 3.0 6.0
#
#[:, :, 2] =
# 7.0 10.0
# 8.0 11.0
# 9.0 12.0
#
# μ will be
# (1. + 2. + 3.) / 3 = 2.
# (4. + 5. + 6.) / 3 = 5.
#
# (7. + 8. + 9.) / 3 = 8.
# (10. + 11. + 12.) / 3 = 11.
#
# ∴ update rule with momentum:
# (1. - .1) * 0 + .1 * (2. + 8.) / 2 = .5
# (1. - .1) * 0 + .1 * (5. + 11.) / 2 = .8
@test m.μ [0.5, 0.8]
# momentum * var * num_items / (num_items - 1) + (1 - momentum) * sigma_sq
# julia> reshape(mean(.1 .* var(x.data, dims = 1, corrected=false) .* (3 / 2), dims=3), :) .+ .9 .* 1.
# 2-element Array{Float64,1}:
# 1.
# 1.
@test m.σ² reshape(mean(.1 .* var(x.data, dims = 1, corrected=false) .* (3 / 2), dims=3), :) .+ .9 .* 1.
testmode!(m)
@test !m.active
x = m(x).data
@test isapprox(x[1], (1 - 0.5) / sqrt(1. + 1f-5), atol = 1.0e-5)
end
  # with activation function
  let m = InstanceNorm(2, sigmoid), sizes = (3, 2, 2),
      x = param(reshape(collect(1:prod(sizes)), sizes))

    affine_shape = collect(sizes)
    affine_shape[1] = 1

    @test m.active
    m(x)

    testmode!(m)
    @test !m.active

    y = m(x).data
    @test isapprox(y, data(sigmoid.((x .- expand_inst(m.μ, affine_shape)) ./ sqrt.(expand_inst(m.σ², affine_shape) .+ m.ϵ))), atol = 1.0e-7)
  end
  let m = InstanceNorm(2), sizes = (2, 4, 1, 2, 3),
      x = param(reshape(collect(1:prod(sizes)), sizes))
    y = reshape(permutedims(x, [3, 1, 2, 4, 5]), :, 2, 3)
    y = reshape(m(y), sizes...)
    @test m(x) == y
  end
  let m = InstanceNorm(32), x = randn(Float32, 416, 416, 32, 1);
    m(x)
    @test (@allocated m(x)) < 100_000_000
  end
end
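
The reshape test above works because the layer reduces over all leading dimensions, so collapsing them into a single axis cannot change the per-(channel, sample) statistics. A standalone sketch of that identity (plain `Statistics`, same shapes as the test, not part of the commit):

```julia
using Statistics

sizes = (2, 4, 1, 2, 3)  # three leading dims, then C = 2, N = 3
x = reshape(collect(1.0:prod(sizes)), sizes)

# Statistics over the three leading axes...
μ5 = mean(x, dims = (1, 2, 3))

# ...equal the statistics after collapsing those axes into one,
# which is why m(x) matches m applied to the reshaped array.
μ3 = mean(reshape(x, :, 2, 3), dims = 1)
@assert vec(μ5) ≈ vec(μ3)
```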