# Flux.jl/src/layers/normalise.jl

istraining() = false

@adjoint istraining() = true, _ -> nothing

_isactive(m) = isnothing(m.active) ? istraining() : m.active

_dropout_shape(s, ::Colon) = size(s)
_dropout_shape(s, dims) = tuple((i ∈ dims ? 1 : si for (i, si) ∈ enumerate(size(s)))...)
_dropout_kernel(y::T, p, q) where {T} = y > p ? T(1 / q) : T(0)

"""
    dropout(x, p; dims = :)

The dropout function. For each input, either sets that input to `0` (with probability
`p`) or scales it by `1 / (1 - p)`. The `dims` keyword gives the dimensions along which
one random choice is shared, e.g. `dims=1` drops out entire columns and `dims=2` entire rows.
This is used as a regularisation, i.e. it reduces overfitting during training.
See also the [`Dropout`](@ref) layer.
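
# Examples

A minimal sketch, assuming `using Flux` (note that in this version of the code the random
mask is only applied when the call is differentiated, e.g. under `gradient`; called
directly, `Flux.dropout` returns `x` unchanged):
```julia
x = ones(Float32, 3, 4)
Flux.dropout(x, 0.5)          # returns x unchanged outside of a gradient context
gradient(x -> sum(Flux.dropout(x, 0.5; dims = 1)), x)  # here whole columns are dropped together
```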
"""
dropout(x, p; dims = :) = x
@adjoint function dropout(x, p; dims = :)
  y = rand!(similar(x, _dropout_shape(x, dims)))
  y .= _dropout_kernel.(y, p, 1 - p)
  return x .* y, Δ -> (Δ .* y, nothing)
end

"""
    Dropout(p, dims = :)

Dropout layer. In the forward pass, applies the [`Flux.dropout`](@ref) function on the input.
Does nothing to the input once [`Flux.testmode!`](@ref) is `true`.
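
# Examples

A usage sketch (the `Dropout`-specific `testmode!` method is defined further down in this
file; recursing through a `Chain` relies on Flux's generic `testmode!`):
```julia
m = Chain(
  Dense(28^2, 64, relu),
  Dropout(0.5),
  Dense(64, 10))

testmode!(m)         # disable dropout, e.g. for evaluation
testmode!(m, false)  # force it back on
```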
"""
mutable struct Dropout{F,D}
  p::F
  dims::D
  active::Union{Bool, Nothing}
end

# TODO: deprecate in v0.11
Dropout(p, dims) = Dropout(p, dims, nothing)

function Dropout(p; dims = :)
  @assert 0 ≤ p ≤ 1
  Dropout{typeof(p),typeof(dims)}(p, dims, nothing)
end

function (a::Dropout)(x)
  _isactive(a) || return x
  return dropout(x, a.p; dims = a.dims)
end

testmode!(m::Dropout, mode = true) =
  (m.active = (isnothing(mode) || mode == :auto) ? nothing : !mode; m)

function Base.show(io::IO, d::Dropout)
  print(io, "Dropout(", d.p)
  d.dims != (:) && print(io, ", dims = $(repr(d.dims))")
  print(io, ")")
end

"""
    AlphaDropout(p)

A dropout layer. Used in
[Self-Normalizing Neural Networks](https://papers.nips.cc/paper/6698-self-normalizing-neural-networks.pdf).
The AlphaDropout layer ensures that mean and variance of activations
remain the same as before.

Does nothing to the input once [`testmode!`](@ref) is true.
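
# Examples

A sketch of typical use alongside `selu` activations, as in the self-normalizing
networks paper:
```julia
m = Chain(
  Dense(28^2, 64, selu),
  AlphaDropout(0.2),
  Dense(64, 10, selu))
```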
"""
mutable struct AlphaDropout{F}
  p::F
  active::Union{Bool, Nothing}
  function AlphaDropout(p, active = nothing)
    @assert 0 ≤ p ≤ 1
    new{typeof(p)}(p, active)
  end
end

function (a::AlphaDropout)(x)
  _isactive(a) || return x
  # SELU constants λ and α; α1 = -λα is the value dropped activations are set to
  λ = eltype(x)(1.0507009873554804934193349852946)
  α = eltype(x)(1.6732632423543772848170429916717)
  α1 = eltype(x)(-λ*α)
  noise = randn(eltype(x), size(x))
  x = @. x*(noise > (1 - a.p)) + α1 * (noise < (1 - a.p))
  # affine correction keeping mean and variance of the activations unchanged
  A = (a.p + a.p * (1 - a.p) * α1 ^ 2)^0.5
  B = -A * α1 * (1 - a.p)
  x = @. A * x + B
  return x
end

testmode!(m::AlphaDropout, mode = true) =
  (m.active = (isnothing(mode) || mode == :auto) ? nothing : !mode; m)

"""
    LayerNorm(h::Integer)

A [normalisation layer](https://arxiv.org/pdf/1607.06450.pdf) designed to be
used with recurrent hidden states of size `h`. Normalises the mean and standard
deviation of each input before applying a per-neuron gain/bias.
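
# Examples

A sketch of use after a recurrent layer with a 64-dimensional hidden state:
```julia
m = Chain(
  LSTM(10, 64),
  LayerNorm(64),
  Dense(64, 5))
```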
"""
struct LayerNorm{T}
  diag::Diagonal{T}
end

LayerNorm(h::Integer) =
  LayerNorm(Diagonal(h))

@functor LayerNorm

(a::LayerNorm)(x) = a.diag(normalise(x))

function Base.show(io::IO, l::LayerNorm)
  print(io, "LayerNorm(", length(l.diag.α), ")")
end

"""
    BatchNorm(channels::Integer, σ = identity;
              initβ = (i) -> zeros(Float32, i), initγ = (i) -> ones(Float32, i),
              ϵ = 1f-5, momentum = 0.1f0)

[Batch Normalization](https://arxiv.org/pdf/1502.03167.pdf) layer.
`channels` should be the size of the channel dimension in your data (see below).

Given an array with `N` dimensions, call the `N-1`th the channel dimension. (For
a batch of feature vectors this is just the data dimension, for `WHCN` images
it's the usual channel dimension.)

`BatchNorm` computes the mean and variance for each `W×H×1×N` slice and
shifts them to have a new mean and variance (corresponding to the learnable,
per-channel `bias` and `scale` parameters).

Use [`testmode!`](@ref) during inference.

# Examples
```julia
m = Chain(
  Dense(28^2, 64),
  BatchNorm(64, relu),
  Dense(64, 10),
  BatchNorm(10),
  softmax)
```
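
Evaluation vs. training behaviour is controlled with `testmode!` (a sketch; the
`BatchNorm` method is defined below, and `:auto` is the default tracking mode):
```julia
testmode!(m)         # use the stored μ and σ² instead of per-batch statistics
testmode!(m, false)  # always use per-batch statistics
testmode!(m, :auto)  # follow the automatic training/test detection again
```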
"""
mutable struct BatchNorm{F,V,W,N}
  λ::F  # activation function
  β::V  # bias
  γ::V  # scale
  μ::W  # moving mean
  σ²::W # moving variance
  ϵ::N
  momentum::N
  active::Union{Bool, Nothing}
end

# TODO: deprecate in v0.11
BatchNorm(λ, β, γ, μ, σ², ϵ, momentum) = BatchNorm(λ, β, γ, μ, σ², ϵ, momentum, nothing)

BatchNorm(chs::Integer, λ = identity;
          initβ = (i) -> zeros(Float32, i), initγ = (i) -> ones(Float32, i),
          ϵ = 1f-5, momentum = 0.1f0) =
  BatchNorm(λ, initβ(chs), initγ(chs),
            zeros(chs), ones(chs), ϵ, momentum, nothing)

trainable(bn::BatchNorm) = (bn.β, bn.γ)

function (BN::BatchNorm)(x)
  size(x, ndims(x)-1) == length(BN.β) ||
    error("BatchNorm expected $(length(BN.β)) channels, got $(size(x, ndims(x)-1))")
  dims = length(size(x))
  channels = size(x, dims-1)
  affine_shape = ntuple(i->i == ndims(x) - 1 ? size(x, i) : 1, ndims(x))
  m = div(prod(size(x)), channels)
  γ = reshape(BN.γ, affine_shape...)
  β = reshape(BN.β, affine_shape...)
  if !_isactive(BN)
    μ = reshape(BN.μ, affine_shape...)
    σ² = reshape(BN.σ², affine_shape...)
    ϵ = BN.ϵ
  else
    T = eltype(x)
    axes = [1:dims-2; dims] # axes to reduce along (all but the channel axis)
    μ = mean(x, dims = axes)
    σ² = sum((x .- μ) .^ 2, dims = axes) ./ m
    ϵ = convert(T, BN.ϵ)
    # update moving mean/variance
    mtm = BN.momentum
    S = eltype(BN.μ)
    BN.μ = (1 - mtm) .* BN.μ .+ mtm .* S.(reshape(μ, :))
    BN.σ² = (1 - mtm) .* BN.σ² .+ (mtm * m / (m - 1)) .* S.(reshape(σ², :))
  end

  let λ = BN.λ
    x̂ = (x .- μ) ./ sqrt.(σ² .+ ϵ)
    λ.(γ .* x̂ .+ β)
  end
end

@functor BatchNorm

testmode!(m::BatchNorm, mode = true) =
  (m.active = (isnothing(mode) || mode == :auto) ? nothing : !mode; m)

function Base.show(io::IO, l::BatchNorm)
  print(io, "BatchNorm($(join(size(l.β), ", "))")
  (l.λ == identity) || print(io, ", λ = $(l.λ)")
  print(io, ")")
end

expand_inst = (x, as) -> reshape(repeat(x, outer=[1, as[length(as)]]), as...)

mutable struct InstanceNorm{F,V,W,N}
  λ::F  # activation function
  β::V  # bias
  γ::V  # scale
  μ::W  # moving mean
  σ²::W # moving variance
  ϵ::N
  momentum::N
  active::Union{Bool, Nothing}
end

"""
    InstanceNorm(channels::Integer, σ = identity;
                 initβ = (i) -> zeros(Float32, i), initγ = (i) -> ones(Float32, i),
                 ϵ = 1f-5, momentum = 0.1f0)

[Instance Normalization](https://arxiv.org/abs/1607.08022) layer.
`channels` should be the size of the channel dimension in your data (see below).

Given an array with `N` dimensions, call the `N-1`th the channel dimension. (For
a batch of feature vectors this is just the data dimension, for `WHCN` images
it's the usual channel dimension.)

`InstanceNorm` computes the mean and variance for each `W×H×1×1` slice and
shifts them to have a new mean and variance (corresponding to the learnable,
per-channel `bias` and `scale` parameters).

Use [`testmode!`](@ref) during inference.

# Examples
```julia
m = Chain(
  Conv((3,3), 1=>32, relu),
  InstanceNorm(32),
  Conv((3,3), 32=>64, relu),
  InstanceNorm(64))
```
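
Note that `InstanceNorm` needs inputs with at least 3 dimensions, e.g. `W×H×C×N`
feature maps; a 2-d `features × batch` matrix is rejected. A sketch:
```julia
m = InstanceNorm(3)
m(rand(Float32, 32, 32, 3, 2))  # fine: W×H×C×N data with 3 channels
m(rand(Float32, 3, 10))         # errors: only 2 dimensions
```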
"""
# TODO: deprecate in v0.11
InstanceNorm(λ, β, γ, μ, σ², ϵ, momentum) = InstanceNorm(λ, β, γ, μ, σ², ϵ, momentum, nothing)

InstanceNorm(chs::Integer, λ = identity;
             initβ = (i) -> zeros(Float32, i), initγ = (i) -> ones(Float32, i),
             ϵ = 1f-5, momentum = 0.1f0) =
  InstanceNorm(λ, initβ(chs), initγ(chs),
               zeros(chs), ones(chs), ϵ, momentum, nothing)

trainable(in::InstanceNorm) = (in.β, in.γ)

function (in::InstanceNorm)(x)
  size(x, ndims(x)-1) == length(in.β) ||
    error("InstanceNorm expected $(length(in.β)) channels, got $(size(x, ndims(x)-1))")
  ndims(x) > 2 ||
    error("InstanceNorm requires at least 3 dimensions. With 2 dimensions an array of zeros would be returned")
  # these are repeated later on depending on the batch size
  dims = length(size(x))
  c = size(x, dims-1)
  bs = size(x, dims)
  affine_shape = ntuple(i->i == ndims(x) - 1 || i == ndims(x) ? size(x, i) : 1, ndims(x))
  m = div(prod(size(x)), c*bs)
  γ, β = expand_inst(in.γ, affine_shape), expand_inst(in.β, affine_shape)

  if !_isactive(in)
    μ = expand_inst(in.μ, affine_shape)
    σ² = expand_inst(in.σ², affine_shape)
    ϵ = in.ϵ
  else
    T = eltype(x)
    ϵ = convert(T, in.ϵ)
    axes = 1:dims-2 # axes to reduce along (all but the channel and batch axes)
    μ = mean(x, dims = axes)
    σ² = mean((x .- μ) .^ 2, dims = axes)
    S = eltype(in.μ)
    # update moving mean/variance
    mtm = in.momentum
    in.μ = dropdims(mean(repeat((1 - mtm) .* in.μ, outer=[1, bs]) .+ mtm .* S.(reshape(μ, (c, bs))), dims = 2), dims=2)
    in.σ² = dropdims(mean((repeat((1 - mtm) .* in.σ², outer=[1, bs]) .+ (mtm * m / (m - 1)) .* S.(reshape(σ², (c, bs)))), dims = 2), dims=2)
  end

  let λ = in.λ
    x̂ = (x .- μ) ./ sqrt.(σ² .+ ϵ)
    λ.(γ .* x̂ .+ β)
  end
end

@functor InstanceNorm

testmode!(m::InstanceNorm, mode = true) =
  (m.active = (isnothing(mode) || mode == :auto) ? nothing : !mode; m)

function Base.show(io::IO, l::InstanceNorm)
  print(io, "InstanceNorm($(join(size(l.β), ", "))")
  (l.λ == identity) || print(io, ", λ = $(l.λ)")
  print(io, ")")
end

"""
    GroupNorm(chs::Integer, G::Integer, λ = identity;
              initβ = (i) -> zeros(Float32, i), initγ = (i) -> ones(Float32, i),
              ϵ = 1f-5, momentum = 0.1f0)

[Group Normalization](https://arxiv.org/pdf/1803.08494.pdf) layer.
This layer can outperform Batch Normalization and Instance Normalization.

`chs` is the number of channels, the channel dimension of your input.
For an array of `N` dimensions, the `N-1`th index is the channel dimension.

`G` is the number of groups along which the statistics are computed.
The number of channels must be an integer multiple of the number of groups.

Use [`testmode!`](@ref) during inference.

# Examples
```julia
m = Chain(Conv((3,3), 1=>32, leakyrelu; pad = 1),
          GroupNorm(32, 16))
# 32 channels, 16 groups (G = 16), thus 2 channels per group used
```
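
The channel count must be divisible by `G`; otherwise the layer errors when applied
(a sketch):
```julia
GroupNorm(32, 4)   # fine: 8 channels per group
GroupNorm(32, 5)   # constructs, but errors once called on data, since 5 does not divide 32
```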
"""
mutable struct GroupNorm{F,V,W,N,T}
  G::T  # number of groups
  λ::F  # activation function
  β::V  # bias
  γ::V  # scale
  μ::W  # moving mean
  σ²::W # moving variance
  ϵ::N
  momentum::N
  active::Union{Bool, Nothing}
end

# TODO: deprecate in v0.11
GroupNorm(G, λ, β, γ, μ, σ², ϵ, momentum) = GroupNorm(G, λ, β, γ, μ, σ², ϵ, momentum, nothing)

GroupNorm(chs::Integer, G::Integer, λ = identity;
          initβ = (i) -> zeros(Float32, i), initγ = (i) -> ones(Float32, i),
          ϵ = 1f-5, momentum = 0.1f0) =
  GroupNorm(G, λ, initβ(chs), initγ(chs),
            zeros(G,1), ones(G,1), ϵ, momentum, nothing)

trainable(gn::GroupNorm) = (gn.β, gn.γ)

function (gn::GroupNorm)(x)
  size(x, ndims(x)-1) == length(gn.β) ||
    error("Group Norm expected $(length(gn.β)) channels, but got $(size(x, ndims(x)-1)) channels")
  ndims(x) > 2 || error("Need to pass at least 3 dimensions for Group Norm to work")
  (size(x, ndims(x)-1)) % gn.G == 0 ||
    error("The number of groups ($(gn.G)) must divide the number of channels ($(size(x, ndims(x)-1)))")

  dims = length(size(x))
  groups = gn.G
  channels = size(x, dims-1)
  batches = size(x, dims)
  channels_per_group = div(channels, groups)
  affine_shape = ntuple(i->i == ndims(x) - 1 ? size(x, i) : 1, ndims(x))

  # Input reshaped to (W,H...,C/G,G,N); group statistics are broadcast with shape (1,...,1,G,1)
  μ_affine_shape = ntuple(i->i == ndims(x) ? groups : 1, ndims(x) + 1)

  m = prod(size(x)[1:end-2]) * channels_per_group
  γ = reshape(gn.γ, affine_shape...)
  β = reshape(gn.β, affine_shape...)

  y = reshape(x, ((size(x))[1:end-2]..., channels_per_group, groups, batches))
  if !_isactive(gn)
    og_shape = size(x)
    μ = reshape(gn.μ, μ_affine_shape...)   # shape (1,...,1,G,1)
    σ² = reshape(gn.σ², μ_affine_shape...) # shape (1,...,1,G,1)
    ϵ = gn.ϵ
  else
    T = eltype(x)
    og_shape = size(x)
    axes = [(1:ndims(y)-2)...] # axes to reduce along (all but the group and batch axes)
    μ = mean(y, dims = axes)
    σ² = mean((y .- μ) .^ 2, dims = axes)
    ϵ = convert(T, gn.ϵ)
    # update moving mean/variance
    mtm = gn.momentum
    S = eltype(gn.μ)
    gn.μ = mean((1 - mtm) .* gn.μ .+ mtm .* S.(reshape(μ, (groups, batches))), dims=2)
    gn.σ² = mean((1 - mtm) .* gn.σ² .+ (mtm * m / (m - 1)) .* S.(reshape(σ², (groups, batches))), dims=2)
  end

  let λ = gn.λ
    x̂ = (y .- μ) ./ sqrt.(σ² .+ ϵ)
    # reshape x̂ back to the original input shape before the affine transform
    x̂ = reshape(x̂, og_shape)
    λ.(γ .* x̂ .+ β)
  end
end

@functor GroupNorm

testmode!(m::GroupNorm, mode = true) =
  (m.active = (isnothing(mode) || mode == :auto) ? nothing : !mode; m)

function Base.show(io::IO, l::GroupNorm)
  print(io, "GroupNorm($(join(size(l.β), ", "))")
  (l.λ == identity) || print(io, ", λ = $(l.λ)")
  print(io, ")")
end