Merge pull request #634 from dhpollack/instancenorm
instance normalization
This commit is contained in:
commit
82578bfb0d
@ -7,7 +7,7 @@ using MacroTools, Juno, Requires, Reexport, Statistics, Random
|
||||
using MacroTools: @forward
|
||||
|
||||
export Chain, Dense, RNN, LSTM, GRU, Conv, ConvTranspose, MaxPool, MeanPool,
|
||||
DepthwiseConv, Dropout, LayerNorm, BatchNorm,
|
||||
DepthwiseConv, Dropout, LayerNorm, BatchNorm, InstanceNorm,
|
||||
params, mapleaves, cpu, gpu, f32, f64
|
||||
|
||||
@reexport using NNlib
|
||||
|
@ -155,3 +155,103 @@ function Base.show(io::IO, l::BatchNorm)
|
||||
(l.λ == identity) || print(io, ", λ = $(l.λ)")
|
||||
print(io, ")")
|
||||
end
|
||||
|
||||
|
||||
"""
|
||||
InstanceNorm(channels::Integer, σ = identity;
|
||||
initβ = zeros, initγ = ones,
|
||||
ϵ = 1e-8, momentum = .1)
|
||||
|
||||
Instance Normalization layer. The `channels` input should be the size of the
|
||||
channel dimension in your data (see below).
|
||||
|
||||
Given an array with `N` dimensions, call the `N-1`th the channel dimension. (For
|
||||
a batch of feature vectors this is just the data dimension, for `WHCN` images
|
||||
it's the usual channel dimension.)
|
||||
|
||||
`InstanceNorm` computes the mean and variance for each each `W×H×1×1` slice and
|
||||
shifts them to have a new mean and variance (corresponding to the learnable,
|
||||
per-channel `bias` and `scale` parameters).
|
||||
|
||||
See [Instance Normalization: The Missing Ingredient for Fast Stylization](https://arxiv.org/abs/1607.08022).
|
||||
|
||||
Example:
|
||||
```julia
|
||||
m = Chain(
|
||||
Dense(28^2, 64),
|
||||
InstanceNorm(64, relu),
|
||||
Dense(64, 10),
|
||||
InstanceNorm(10),
|
||||
softmax)
|
||||
```
|
||||
"""
|
||||
expand_inst = (x, as) -> reshape(repeat(x, outer=[1, as[length(as)]]), as...)
|
||||
|
||||
mutable struct InstanceNorm{F,V,W,N}
|
||||
λ::F # activation function
|
||||
β::V # bias
|
||||
γ::V # scale
|
||||
μ::W # moving mean
|
||||
σ²::W # moving std
|
||||
ϵ::N
|
||||
momentum::N
|
||||
active::Bool
|
||||
end
|
||||
|
||||
InstanceNorm(chs::Integer, λ = identity;
|
||||
initβ = (i) -> zeros(Float32, i), initγ = (i) -> ones(Float32, i), ϵ = 1f-5, momentum = 0.1f0) =
|
||||
InstanceNorm(λ, param(initβ(chs)), param(initγ(chs)),
|
||||
zeros(chs), ones(chs), ϵ, momentum, true)
|
||||
|
||||
function (in::InstanceNorm)(x)
|
||||
size(x, ndims(x)-1) == length(in.β) ||
|
||||
error("InstanceNorm expected $(length(in.β)) channels, got $(size(x, ndims(x)-1))")
|
||||
ndims(x) > 2 ||
|
||||
error("InstanceNorm requires at least 3 dimensions. With 2 dimensions an array of zeros would be returned")
|
||||
# these are repeated later on depending on the batch size
|
||||
dims = length(size(x))
|
||||
c = size(x, dims-1)
|
||||
bs = size(x, dims)
|
||||
affine_shape = ones(Int, dims)
|
||||
affine_shape[end-1] = c
|
||||
affine_shape[end] = bs
|
||||
m = prod(size(x)[1:end-2])
|
||||
γ, β = expand_inst(in.γ, affine_shape), expand_inst(in.β, affine_shape)
|
||||
|
||||
if !in.active
|
||||
μ = expand_inst(in.μ, affine_shape)
|
||||
σ² = expand_inst(in.σ², affine_shape)
|
||||
ϵ = in.ϵ
|
||||
else
|
||||
T = eltype(x)
|
||||
|
||||
ϵ = data(convert(T, in.ϵ))
|
||||
axes = 1:dims-2 # axes to reduce along (all but channels and batch size axes)
|
||||
μ = mean(x, dims = axes)
|
||||
σ² = mean((x .- μ) .^ 2, dims = axes)
|
||||
|
||||
# update moving mean/std
|
||||
mtm = data(convert(T, in.momentum))
|
||||
in.μ = dropdims(mean(repeat((1 - mtm) .* in.μ, outer=[1, bs]) .+ mtm .* reshape(data(μ), (c, bs)), dims = 2), dims=2)
|
||||
in.σ² = dropdims(mean((repeat((1 - mtm) .* in.σ², outer=[1, bs]) .+ (mtm * m / (m - 1)) .* reshape(data(σ²), (c, bs))), dims = 2), dims=2)
|
||||
end
|
||||
|
||||
let λ = in.λ
|
||||
x̂ = (x .- μ) ./ sqrt.(σ² .+ ϵ)
|
||||
λ.(γ .* x̂ .+ β)
|
||||
end
|
||||
end
|
||||
|
||||
children(in::InstanceNorm) =
|
||||
(in.λ, in.β, in.γ, in.μ, in.σ², in.ϵ, in.momentum, in.active)
|
||||
|
||||
mapchildren(f, in::InstanceNorm) = # e.g. mapchildren(cu, in)
|
||||
InstanceNorm(in.λ, f(in.β), f(in.γ), f(in.μ), f(in.σ²), in.ϵ, in.momentum, in.active)
|
||||
|
||||
_testmode!(in::InstanceNorm, test) = (in.active = !test)
|
||||
|
||||
function Base.show(io::IO, l::InstanceNorm)
|
||||
print(io, "InstanceNorm($(join(size(l.β), ", "))")
|
||||
(l.λ == identity) || print(io, ", λ = $(l.λ)")
|
||||
print(io, ")")
|
||||
end
|
||||
|
@ -104,3 +104,99 @@ end
|
||||
@test (@allocated m(x)) < 100_000_000
|
||||
end
|
||||
end
|
||||
|
||||
|
||||
@testset "InstanceNorm" begin
|
||||
# helper functions
|
||||
expand_inst = (x, as) -> reshape(repeat(x, outer=[1, as[length(as)]]), as...)
|
||||
# begin tests
|
||||
let m = InstanceNorm(2), sizes = (3, 2, 2),
|
||||
x = param(reshape(collect(1:prod(sizes)), sizes))
|
||||
|
||||
@test m.β.data == [0, 0] # initβ(2)
|
||||
@test m.γ.data == [1, 1] # initγ(2)
|
||||
|
||||
@test m.active
|
||||
|
||||
m(x)
|
||||
|
||||
#julia> x
|
||||
#[:, :, 1] =
|
||||
# 1.0 4.0
|
||||
# 2.0 5.0
|
||||
# 3.0 6.0
|
||||
#
|
||||
#[:, :, 2] =
|
||||
# 7.0 10.0
|
||||
# 8.0 11.0
|
||||
# 9.0 12.0
|
||||
#
|
||||
# μ will be
|
||||
# (1. + 2. + 3.) / 3 = 2.
|
||||
# (4. + 5. + 6.) / 3 = 5.
|
||||
#
|
||||
# (7. + 8. + 9.) / 3 = 8.
|
||||
# (10. + 11. + 12.) / 3 = 11.
|
||||
#
|
||||
# ∴ update rule with momentum:
|
||||
# (1. - .1) * 0 + .1 * (2. + 8.) / 2 = .5
|
||||
# (1. - .1) * 0 + .1 * (5. + 11.) / 2 = .8
|
||||
@test m.μ ≈ [0.5, 0.8]
|
||||
# momentum * var * num_items / (num_items - 1) + (1 - momentum) * sigma_sq
|
||||
# julia> reshape(mean(.1 .* var(x.data, dims = 1, corrected=false) .* (3 / 2), dims=3), :) .+ .9 .* 1.
|
||||
# 2-element Array{Float64,1}:
|
||||
# 1.
|
||||
# 1.
|
||||
@test m.σ² ≈ reshape(mean(.1 .* var(x.data, dims = 1, corrected=false) .* (3 / 2), dims=3), :) .+ .9 .* 1.
|
||||
|
||||
testmode!(m)
|
||||
@test !m.active
|
||||
|
||||
x′ = m(x).data
|
||||
@test isapprox(x′[1], (1 - 0.5) / sqrt(1. + 1f-5), atol = 1.0e-5)
|
||||
end
|
||||
# with activation function
|
||||
let m = InstanceNorm(2, sigmoid), sizes = (3, 2, 2),
|
||||
x = param(reshape(collect(1:prod(sizes)), sizes))
|
||||
|
||||
affine_shape = collect(sizes)
|
||||
affine_shape[1] = 1
|
||||
|
||||
@test m.active
|
||||
m(x)
|
||||
|
||||
testmode!(m)
|
||||
@test !m.active
|
||||
|
||||
y = m(x).data
|
||||
@test isapprox(y, data(sigmoid.((x .- expand_inst(m.μ, affine_shape)) ./ sqrt.(expand_inst(m.σ², affine_shape) .+ m.ϵ))), atol = 1.0e-7)
|
||||
end
|
||||
|
||||
let m = InstanceNorm(2), sizes = (2, 4, 1, 2, 3),
|
||||
x = param(reshape(collect(1:prod(sizes)), sizes))
|
||||
y = reshape(permutedims(x, [3, 1, 2, 4, 5]), :, 2, 3)
|
||||
y = reshape(m(y), sizes...)
|
||||
@test m(x) == y
|
||||
end
|
||||
|
||||
# check that μ, σ², and the output are the correct size for higher rank tensors
|
||||
let m = InstanceNorm(2), sizes = (5, 5, 3, 4, 2, 6),
|
||||
x = param(reshape(collect(1:prod(sizes)), sizes))
|
||||
y = m(x)
|
||||
@test size(m.μ) == (sizes[end - 1], )
|
||||
@test size(m.σ²) == (sizes[end - 1], )
|
||||
@test size(y) == sizes
|
||||
end
|
||||
|
||||
# show that instance norm is equal to batch norm when channel and batch dims are squashed
|
||||
let m_inorm = InstanceNorm(2), m_bnorm = BatchNorm(12), sizes = (5, 5, 3, 4, 2, 6),
|
||||
x = param(reshape(collect(1:prod(sizes)), sizes))
|
||||
@test m_inorm(x) == reshape(m_bnorm(reshape(x, (sizes[1:end - 2]..., :, 1))), sizes)
|
||||
end
|
||||
|
||||
let m = InstanceNorm(32), x = randn(Float32, 416, 416, 32, 1);
|
||||
m(x)
|
||||
@test (@allocated m(x)) < 100_000_000
|
||||
end
|
||||
|
||||
end
|
||||
|
Loading…
Reference in New Issue
Block a user