instance normalization
This commit is contained in:
parent
3a4c6274fa
commit
129a708b6f
|
@ -7,7 +7,7 @@ using MacroTools, Juno, Requires, Reexport, Statistics, Random
|
|||
using MacroTools: @forward
|
||||
|
||||
export Chain, Dense, RNN, LSTM, GRU, Conv, ConvTranspose, MaxPool, MeanPool,
|
||||
DepthwiseConv, Dropout, LayerNorm, BatchNorm,
|
||||
DepthwiseConv, Dropout, LayerNorm, BatchNorm, InstanceNorm,
|
||||
params, mapleaves, cpu, gpu, f32, f64
|
||||
|
||||
@reexport using NNlib
|
||||
|
|
|
@ -155,3 +155,101 @@ function Base.show(io::IO, l::BatchNorm)
|
|||
(l.λ == identity) || print(io, ", λ = $(l.λ)")
|
||||
print(io, ")")
|
||||
end
|
||||
|
||||
|
||||
"""
|
||||
    InstanceNorm(channels::Integer, σ = identity;
                 initβ = i -> zeros(Float32, i), initγ = i -> ones(Float32, i),
                 ϵ = 1f-5, momentum = 0.1f0)
|
||||
|
||||
Instance Normalization layer. The `channels` input should be the size of the
|
||||
channel dimension in your data (see below).
|
||||
|
||||
Given an array with `N` dimensions, call the `N-1`th the channel dimension and
the `N`th the batch dimension. (For `WHCN` images the channel dimension is the
usual one.) The input must have at least 3 dimensions: with only 2 there is
nothing left to normalise over, so the layer throws an error.
|
||||
|
||||
`InstanceNorm` computes the mean and variance for each `W×H×1×1` slice and
shifts them to have a new mean and variance (corresponding to the learnable,
per-channel `bias` and `scale` parameters).
|
||||
|
||||
See [Instance Normalization: The Missing Ingredient for Fast Stylization](https://arxiv.org/abs/1607.08022).
|
||||
|
||||
Example:
|
||||
```julia
|
||||
m = Chain(
  Conv((3, 3), 1 => 16),
  InstanceNorm(16, relu),
  Conv((3, 3), 16 => 32),
  InstanceNorm(32, relu))
|
||||
```
|
||||
"""
|
||||
mutable struct InstanceNorm{F,V,W,N}
  λ::F  # activation function, broadcast over the normalised and shifted output
  β::V  # bias: learnable per-channel shift
  γ::V  # scale: learnable per-channel multiplier
  μ::W  # moving mean, updated on every call while `active` is true
  σ²::W  # moving variance (not the standard deviation), updated alongside `μ`
  ϵ::N  # small constant added to the variance before taking the square root
  momentum::N  # weight given to the current batch when updating `μ` and `σ²`
  active::Bool  # true: use batch statistics and update the moving ones; false: use the moving ones
end
|
||||
|
||||
# Build an `InstanceNorm` layer for `chs` channels.
#
# `λ` is the activation applied to the output. `initβ` and `initγ` are called
# with `chs` and their results are wrapped with `param` to become the bias and
# scale. The moving mean starts at zero and the moving variance at one. The
# layer starts in the active state.
function InstanceNorm(chs::Integer, λ = identity;
                      initβ = (i) -> zeros(Float32, i),
                      initγ = (i) -> ones(Float32, i),
                      ϵ = 1f-5, momentum = 0.1f0)
  bias = param(initβ(chs))
  scale = param(initγ(chs))
  moving_mean = zeros(chs)
  moving_var = ones(chs)
  return InstanceNorm(λ, bias, scale, moving_mean, moving_var, ϵ, momentum, true)
end
|
||||
|
||||
# Apply instance normalisation to `x`.
#
# `x` must have at least 3 dimensions; the second-to-last is the channel
# dimension (its size must equal `length(IN.β)`) and the last is the batch
# dimension. Every (channel, batch) slice is normalised over all remaining
# leading dimensions, then scaled by `γ`, shifted by `β` and passed through `λ`.
#
# While `IN.active` is true the statistics are computed from `x` and the moving
# mean/variance stored in `IN` are updated as a side effect. Otherwise the
# stored moving statistics are used and `IN` is left untouched.
#
# Throws an `ErrorException` when `x` has fewer than 3 dimensions or the
# channel count does not match.
function (IN::InstanceNorm)(x)
  # Check the rank first so that a 1-d input reports the intended message
  # instead of failing inside `size(x, 0)`.
  ndims(x) > 2 ||
    error("InstanceNorm requires at least 3 dimensions. With 2 dimensions an array of zeros would be returned")
  size(x, ndims(x)-1) == length(IN.β) ||
    error("InstanceNorm expected $(length(IN.β)) channels, got $(size(x, ndims(x)-1))")

  dims = ndims(x)
  c = size(x, dims-1)
  bs = size(x, dims)
  # Shape (1, ..., 1, c, bs): per-channel vectors are repeated once per batch
  # element and reshaped to this so they broadcast against `x`.
  affine_shape = ones(Int, dims)
  affine_shape[end-1] = c
  affine_shape[end] = bs
  # Number of elements in each normalised slice.
  m = prod(size(x)[1:end-2])
  expand = v -> reshape(repeat(v, outer=[bs]), affine_shape...)

  if !IN.active
    μ = expand(IN.μ)
    σ² = expand(IN.σ²)
    ϵ = IN.ϵ
  else
    T = eltype(x)

    # Convert ϵ to the input element type so the normalisation below does not
    # promote the result to a wider type.
    ϵ = data(convert(T, IN.ϵ))
    reduce_dims = 1:dims-2 # reduce along everything but the channel and batch axes
    μ = mean(x, dims = reduce_dims)
    σ² = mean((x .- μ) .^ 2, dims = reduce_dims)

    # Update moving mean/variance with the batch average of the per-instance
    # statistics.
    mtm = data(convert(T, IN.momentum))
    # Bessel's correction for the stored variance. With a single element per
    # slice `m / (m - 1)` is Inf and `0 * Inf` would store NaN, so skip it.
    correction = m > 1 ? m / (m - 1) : 1.0
    batch_μ = mean(reshape(data(μ), (c, bs)), dims = 2)
    batch_σ² = mean(reshape(data(σ²), (c, bs)), dims = 2)
    IN.μ = reshape((1 - mtm) .* IN.μ .+ mtm .* batch_μ, :)
    IN.σ² = reshape((1 - mtm) .* IN.σ² .+ (mtm * correction) .* batch_σ², :)
  end

  let λ = IN.λ
    temp = expand(IN.γ) .* ((x .- μ) ./ sqrt.(σ² .+ ϵ))
    # This is intentionally not fused because of an extreme slowdown doing so
    λ.(temp .+ expand(IN.β))
  end
end
|
||||
|
||||
# Return every field of the layer as a tuple, in declaration order.
function children(IN::InstanceNorm)
  return (IN.λ, IN.β, IN.γ, IN.μ, IN.σ², IN.ϵ, IN.momentum, IN.active)
end
|
||||
|
||||
# Rebuild the layer with `f` applied to its four array fields (bias, scale,
# moving mean, moving variance); e.g. mapchildren(cu, IN). The activation,
# ϵ, momentum and active flag are carried over unchanged.
function mapchildren(f, IN::InstanceNorm)
  bias = f(IN.β)
  scale = f(IN.γ)
  moving_mean = f(IN.μ)
  moving_var = f(IN.σ²)
  return InstanceNorm(IN.λ, bias, scale, moving_mean, moving_var,
                      IN.ϵ, IN.momentum, IN.active)
end
|
||||
|
||||
# Switch the layer between test mode (`test == true`, moving statistics are
# used) and training mode (`test == false`, batch statistics are used).
function _testmode!(IN::InstanceNorm, test)
  return IN.active = !test
end
|
||||
|
||||
# Print a compact description: the channel count, plus the activation when it
# is not `identity`.
function Base.show(io::IO, l::InstanceNorm)
  channels = join(size(l.β), ", ")
  print(io, "InstanceNorm($(channels)")
  if !(l.λ == identity)
    print(io, ", λ = $(l.λ)")
  end
  print(io, ")")
end
|
||||
|
|
|
@ -104,3 +104,84 @@ end
|
|||
@test (@allocated m(x)) < 100_000_000
|
||||
end
|
||||
end
|
||||
|
||||
|
||||
@testset "InstanceNorm" begin
  # Helper: repeat a per-channel vector once per batch element and reshape it
  # to `as` so it broadcasts against the input.
  expand_inst = (x, as) -> reshape(repeat(x, outer=[1, as[length(as)]]), as...)

  # Default layer: initial parameters, moving-statistic update, test mode.
  let m = InstanceNorm(2), sizes = (3, 2, 2),
      x = param(reshape(collect(1:prod(sizes)), sizes))

    @test m.β.data == [0, 0]  # initβ(2)
    @test m.γ.data == [1, 1]  # initγ(2)

    @test m.active

    m(x)

    #julia> x
    #[:, :, 1] =
    # 1.0  4.0
    # 2.0  5.0
    # 3.0  6.0
    #
    #[:, :, 2] =
    # 7.0  10.0
    # 8.0  11.0
    # 9.0  12.0
    #
    # μ will be
    # (1. + 2. + 3.) / 3 = 2.
    # (4. + 5. + 6.) / 3 = 5.
    #
    # (7. + 8. + 9.) / 3 = 8.
    # (10. + 11. + 12.) / 3 = 11.
    #
    # ∴ update rule with momentum:
    # (1. - .1) * 0 + .1 * (2. + 8.) / 2 = .5
    # (1. - .1) * 0 + .1 * (5. + 11.) / 2 = .8
    @test m.μ ≈ [0.5, 0.8]
    # momentum * var * num_items / (num_items - 1) + (1 - momentum) * sigma_sq
    # julia> reshape(mean(.1 .* var(x.data, dims = 1, corrected=false) .* (3 / 2), dims=3), :) .+ .9 .* 1.
    # 2-element Array{Float64,1}:
    #  1.
    #  1.
    @test m.σ² ≈ reshape(mean(.1 .* var(x.data, dims = 1, corrected=false) .* (3 / 2), dims=3), :) .+ .9 .* 1.

    testmode!(m)
    @test !m.active

    x′ = m(x).data
    @test isapprox(x′[1], (1 - 0.5) / sqrt(1. + 1f-5), atol = 1.0e-5)
  end

  # with activation function
  let m = InstanceNorm(2, sigmoid), sizes = (3, 2, 2),
      x = param(reshape(collect(1:prod(sizes)), sizes))

    affine_shape = collect(sizes)
    affine_shape[1] = 1

    @test m.active
    m(x)

    testmode!(m)
    @test !m.active

    y = m(x).data
    @test isapprox(y, data(sigmoid.((x .- expand_inst(m.μ, affine_shape)) ./ sqrt.(expand_inst(m.σ², affine_shape) .+ m.ϵ))), atol = 1.0e-7)
  end

  # Flattening the leading (normalised) dimensions must not change the result.
  let m = InstanceNorm(2), sizes = (2, 4, 1, 2, 3),
      x = param(reshape(collect(1:prod(sizes)), sizes))
    y = reshape(permutedims(x, [3, 1, 2, 4, 5]), :, 2, 3)
    y = reshape(m(y), sizes...)
    @test m(x) == y
  end

  # Allocation bound. This must exercise InstanceNorm, not BatchNorm, otherwise
  # the layer under test is never measured.
  let m = InstanceNorm(32), x = randn(Float32, 416, 416, 32, 1)
    m(x)  # warm-up call so compilation is not counted
    @test (@allocated m(x)) < 100_000_000
  end

end
|
||||
|
|
Loading…
Reference in New Issue