Added testmode! functionality back to normalization layers.
parent 88b0c65d72
commit 7c12af065a
@@ -11,7 +11,7 @@ export gradient

 export Chain, Dense, Maxout, RNN, LSTM, GRU, Conv, CrossCor, ConvTranspose, MaxPool, MeanPool,
        DepthwiseConv, Dropout, AlphaDropout, LayerNorm, BatchNorm, InstanceNorm, GroupNorm,
-       SkipConnection, params, fmap, cpu, gpu, f32, f64
+       SkipConnection, params, fmap, cpu, gpu, f32, f64, testmode!

 include("optimise/Optimise.jl")
 using .Optimise

@@ -2,6 +2,23 @@ istraining() = false

 @adjoint istraining() = true, _ -> nothing

+_isactive(m) = isnothing(m.active) ? istraining() : m.active
+# @adjoint _isactive(m) = _isactive(m), Δ -> nothing
+
+"""
+    testmode!(m, mode = :auto)
+
+Set a layer or model's test mode (see below).
+Using `:auto` mode will treat any gradient computation as training.
+
+Possible values include:
+- `false` for training
+- `true` for testing
+- `:auto` or `nothing` for Flux to detect the mode automatically
+"""
+testmode!(m, mode) = nothing
+testmode!(m::Chain, mode = :auto) = map(x -> testmode!(x, mode), m.layers)
+
 _dropout_shape(s, ::Colon) = size(s)
 _dropout_shape(s, dims) = tuple((i ∉ dims ? 1 : si for (i, si) ∈ enumerate(size(s)))...)

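For context, a minimal usage sketch of the API added in this hunk (illustrative only, not part of the diff; assumes Flux at this commit):

```julia
using Flux

m = Chain(Dense(10, 10), Dropout(0.5), BatchNorm(10))

# Default state is `nothing` (`:auto`): stateful layers are active only while a
# gradient is being computed, because `_isactive` falls back to `istraining()`.
testmode!(m)         # same as testmode!(m, :auto); recurses into m.layers

# Force inference behaviour, even inside `gradient`/`pullback`:
testmode!(m, true)

# Force training behaviour, even outside a gradient computation:
testmode!(m, false)
```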
@@ -22,18 +39,27 @@ A Dropout layer. For each input, either sets that input to `0` (with probability
 `p`) or scales it by `1/(1-p)`. The `dims` argument specifies the unbroadcasted
 dimensions, i.e. `dims=1` does dropout along columns and `dims=2` along rows. This is
 used as a regularisation, i.e. it reduces overfitting during training. See also [`dropout`](@ref).
+
+Does nothing to the input once [`testmode!`](@ref) is set to `true`.
 """
 mutable struct Dropout{F,D}
   p::F
   dims::D
+  active::Union{Bool, Nothing}
 end

 function Dropout(p; dims = :)
   @assert 0 ≤ p ≤ 1
-  Dropout{typeof(p),typeof(dims)}(p, dims)
+  Dropout{typeof(p),typeof(dims)}(p, dims, nothing)
 end

-(a::Dropout)(x) = dropout(x, a.p; dims = a.dims)
+function (a::Dropout)(x)
+  _isactive(a) || return x
+  return dropout(x, a.p; dims = a.dims)
+end
+
+testmode!(m::Dropout, mode = :auto) =
+  (m.active = (isnothing(mode) || mode == :auto) ? nothing : !mode)

 function Base.show(io::IO, d::Dropout)
   print(io, "Dropout(", d.p)

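A rough sketch of how the new `active` field changes `Dropout` behaviour (assumes Zygote's `pullback`; the values are illustrative):

```julia
using Flux
using Zygote: pullback

x = rand(Float32, 1000)
d = Dropout(0.9)              # d.active === nothing, i.e. automatic mode

d(x) == x                     # true: no gradient context, so dropout is a no-op
y, _ = pullback(d, x)         # istraining() is true inside pullback → dropout applied
count(iszero, y) > 500        # roughly 90% of the entries are zeroed

testmode!(d, true)            # force test mode
y, _ = pullback(d, x)
y == x                        # true: the manual setting overrides istraining()
```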
@@ -46,17 +72,20 @@ end
 A dropout layer. It is used in Self-Normalizing Neural Networks.
 (https://papers.nips.cc/paper/6698-self-normalizing-neural-networks.pdf)
 The AlphaDropout layer ensures that the mean and variance of the activations remain the same as before.
+
+Does nothing to the input once [`testmode!`](@ref) is set to `true`.
 """
 mutable struct AlphaDropout{F}
   p::F
-  function AlphaDropout(p)
+  active::Union{Bool, Nothing}
+  function AlphaDropout(p, active = nothing)
     @assert 0 ≤ p ≤ 1
-    new{typeof(p)}(p)
+    new{typeof(p)}(p, active)
   end
 end

 function (a::AlphaDropout)(x)
-  istraining() || return x
+  _isactive(a) || return x
   λ = eltype(x)(1.0507009873554804934193349852946)
   α = eltype(x)(1.6732632423543772848170429916717)
   α1 = eltype(x)(-λ*α)

@@ -68,6 +97,9 @@ function (a::AlphaDropout)(x)
   return x
 end

+testmode!(m::AlphaDropout, mode = :auto) =
+  (m.active = (isnothing(mode) || mode == :auto) ? nothing : !mode)
+
 """
     LayerNorm(h::Integer)

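The same one-line `testmode!` method is repeated for each stateful layer below; as a sketch, this is the mode-to-`active` mapping it implements (values illustrative):

```julia
a = AlphaDropout(0.5)                      # a.active === nothing by default
testmode!(a, true);  a.active == false     # test mode: the forward pass is a no-op
testmode!(a, false); a.active == true      # train mode: noise is always injected
testmode!(a, :auto); a.active === nothing  # defer to istraining() again
```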
@@ -106,6 +138,8 @@ it's the usual channel dimension.)
 shifts them to have a new mean and variance (corresponding to the learnable,
 per-channel `bias` and `scale` parameters).

+Use [`testmode!`](@ref) during inference.
+
 See [Batch Normalization: Accelerating Deep Network Training by Reducing
 Internal Covariate Shift](https://arxiv.org/pdf/1502.03167.pdf).

@@ -127,12 +161,13 @@ mutable struct BatchNorm{F,V,W,N}
   σ²::W  # moving std
   ϵ::N
   momentum::N
+  active::Union{Bool, Nothing}
 end

 BatchNorm(chs::Integer, λ = identity;
           initβ = (i) -> zeros(Float32, i), initγ = (i) -> ones(Float32, i), ϵ = 1f-5, momentum = 0.1f0) =
   BatchNorm(λ, initβ(chs), initγ(chs),
-            zeros(chs), ones(chs), ϵ, momentum)
+            zeros(chs), ones(chs), ϵ, momentum, nothing)

 trainable(bn::BatchNorm) = (bn.β, bn.γ)

@@ -145,7 +180,7 @@ function (BN::BatchNorm)(x)
   m = div(prod(size(x)), channels)
   γ = reshape(BN.γ, affine_shape...)
   β = reshape(BN.β, affine_shape...)
-  if !istraining()
+  if !_isactive(BN)
     μ = reshape(BN.μ, affine_shape...)
     σ² = reshape(BN.σ², affine_shape...)
     ϵ = BN.ϵ

@@ -170,6 +205,9 @@ end

 @functor BatchNorm

+testmode!(m::BatchNorm, mode = :auto) =
+  (m.active = (isnothing(mode) || mode == :auto) ? nothing : !mode)
+
 function Base.show(io::IO, l::BatchNorm)
   print(io, "BatchNorm($(join(size(l.β), ", "))")
   (l.λ == identity) || print(io, ", λ = $(l.λ)")

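A hedged sketch of what the `istraining()` → `_isactive(BN)` switch means for `BatchNorm` in practice (sizes are illustrative):

```julia
using Flux
using Zygote: pullback

bn = BatchNorm(2)
x = rand(Float32, 2, 10)

y_train, _ = pullback(bn, x)  # training path: batch statistics, moving μ/σ² updated
y_test = bn(x)                # auto mode outside a gradient: stored moving statistics
testmode!(bn, false)          # force the batch-statistics path even without a gradient
```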
@@ -193,6 +231,8 @@ it's the usual channel dimension.)
 shifts them to have a new mean and variance (corresponding to the learnable,
 per-channel `bias` and `scale` parameters).

+Use [`testmode!`](@ref) during inference.
+
 See [Instance Normalization: The Missing Ingredient for Fast Stylization](https://arxiv.org/abs/1607.08022).

 Example:

@@ -215,12 +255,13 @@ mutable struct InstanceNorm{F,V,W,N}
   σ²::W  # moving std
   ϵ::N
   momentum::N
+  active::Union{Bool, Nothing}
 end

 InstanceNorm(chs::Integer, λ = identity;
              initβ = (i) -> zeros(Float32, i), initγ = (i) -> ones(Float32, i), ϵ = 1f-5, momentum = 0.1f0) =
   InstanceNorm(λ, initβ(chs), initγ(chs),
-               zeros(chs), ones(chs), ϵ, momentum)
+               zeros(chs), ones(chs), ϵ, momentum, nothing)

 trainable(in::InstanceNorm) = (in.β, in.γ)

@@ -237,7 +278,7 @@ function (in::InstanceNorm)(x)
   m = div(prod(size(x)), c*bs)
   γ, β = expand_inst(in.γ, affine_shape), expand_inst(in.β, affine_shape)

-  if !istraining()
+  if !_isactive(in)
     μ = expand_inst(in.μ, affine_shape)
     σ² = expand_inst(in.σ², affine_shape)
     ϵ = in.ϵ

@@ -263,6 +304,9 @@ end

 @functor InstanceNorm

+testmode!(m::InstanceNorm, mode = :auto) =
+  (m.active = (isnothing(mode) || mode == :auto) ? nothing : !mode)
+
 function Base.show(io::IO, l::InstanceNorm)
   print(io, "InstanceNorm($(join(size(l.β), ", "))")
   (l.λ == identity) || print(io, ", λ = $(l.λ)")

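Because `testmode!(m::Chain, mode)` simply maps over `m.layers`, toggling a whole model reaches every stateful layer, including `InstanceNorm`; a small sketch (the model is hypothetical):

```julia
using Flux

m = Chain(Conv((3, 3), 1 => 16, relu), InstanceNorm(16))
testmode!(m, true)            # Conv hits the no-op fallback; InstanceNorm is toggled
m.layers[2].active == false   # forced into test mode
```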
@@ -283,6 +327,8 @@ For an array of N dimensions, the (N-1)th index is the channel dimension.
 ``G`` is the number of groups along which the statistics would be computed.
 The number of channels must be an integer multiple of the number of groups.

+Use [`testmode!`](@ref) during inference.
+
 Example:
 ```
 m = Chain(Conv((3,3), 1=>32, leakyrelu;pad = 1),

@@ -300,12 +346,13 @@ mutable struct GroupNorm{F,V,W,N,T}
   σ²::W  # moving std
   ϵ::N
   momentum::N
+  active::Union{Bool, Nothing}
 end

 GroupNorm(chs::Integer, G::Integer, λ = identity;
           initβ = (i) -> zeros(Float32, i), initγ = (i) -> ones(Float32, i), ϵ = 1f-5, momentum = 0.1f0) =
   GroupNorm(G, λ, initβ(chs), initγ(chs),
-            zeros(G,1), ones(G,1), ϵ, momentum)
+            zeros(G,1), ones(G,1), ϵ, momentum, nothing)

 trainable(gn::GroupNorm) = (gn.β, gn.γ)

@@ -329,7 +376,7 @@ function(gn::GroupNorm)(x)
   β = reshape(gn.β, affine_shape...)

   y = reshape(x,((size(x))[1:end-2]...,channels_per_group,groups,batches))
-  if !istraining()
+  if !_isactive(gn)
     og_shape = size(x)
     μ = reshape(gn.μ, μ_affine_shape...) # Shape : (1,1,...C/G,G,1)
     σ² = reshape(gn.σ², μ_affine_shape...) # Shape : (1,1,...C/G,G,1)

@@ -360,6 +407,9 @@ end

 @functor GroupNorm

+testmode!(m::GroupNorm, mode = :auto) =
+  (m.active = (isnothing(mode) || mode == :auto) ? nothing : !mode)
+
 function Base.show(io::IO, l::GroupNorm)
   print(io, "GroupNorm($(join(size(l.β), ", "))")
   (l.λ == identity) || print(io, ", λ = $(l.λ)")

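And a corresponding sketch for `GroupNorm` (channel and group counts are illustrative):

```julia
using Flux

gn = GroupNorm(32, 4)             # 32 channels split into 4 groups
x = rand(Float32, 8, 8, 32, 2)    # WHCN input
y = gn(x)                         # auto mode, no gradient: uses the moving statistics
testmode!(gn, false)              # force the per-batch (training) statistics path
y = gn(x)
```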
@@ -1,30 +1,33 @@
 using Flux, Test, Statistics
 using Zygote: pullback

-trainmode(f, x...) = pullback(f, x...)[1]
-trainmode(f) = (x...) -> trainmode(f, x...)
+evalwgrad(f, x...) = pullback(f, x...)[1]
+trainmode(f) = (testmode!(f, false); f)

 @testset "Dropout" begin
   x = [1.,2.,3.]
   @test x == Dropout(0.1)(x)
-  @test x == trainmode(Dropout(0), x)
-  @test zero(x) == trainmode(Dropout(1), x)
+  @test x == evalwgrad(Dropout(0), x)
+  @test zero(x) == evalwgrad(Dropout(1), x)

   x = rand(100)
   m = Dropout(0.9)
-  y = trainmode(m, x)
+  y = evalwgrad(m, x)
   @test count(a->a==0, y) > 50
-  y = m(x)
+  testmode!(m, true)
+  y = evalwgrad(m, x) # should override istraining
   @test count(a->a==0, y) == 0
-  y = trainmode(m, x)
+  testmode!(m, false)
+  y = evalwgrad(m, x)
   @test count(a->a==0, y) > 50

   x = rand(Float32, 100)
   m = Chain(Dense(100,100),
             Dropout(0.9))
-  y = trainmode(m, x)
+  y = evalwgrad(m, x)
   @test count(a->a == 0, y) > 50
-  y = m(x)
+  testmode!(m, true)
+  y = evalwgrad(m, x) # should override istraining
   @test count(a->a == 0, y) == 0

   x = rand(100, 50)

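To spell out the renamed test helpers: `evalwgrad` just evaluates `f` under a `pullback`, so `istraining()` is `true` and auto-mode layers behave as in training, while the new `trainmode(f)` mutates `f` via `testmode!(f, false)` and returns it. A small illustrative sketch (not additional test code):

```julia
m = Dropout(0.5)
evalwgrad(m, ones(10))    # dropout applied: evaluation happens inside a pullback
trainmode(m)(ones(10))    # dropout applied without any gradient context
m(ones(10))               # still applied: trainmode set m.active = true in place
```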
@@ -49,7 +52,7 @@ end
     # initial m.σ is 1
     # initial m.μ is 0

-    y = trainmode(m, x)
+    y = evalwgrad(m, x)
     @test isapprox(y, [-1.22474 0 1.22474; -1.22474 0 1.22474], atol = 1.0e-5)
     # julia> x
     #  2×3 Array{Float64,2}:

@@ -117,7 +120,7 @@ end
     x = Float64.(x)
     @test m.β == [0, 0]  # initβ(2)
     @test m.γ == [1, 1]  # initγ(2)
-    y = trainmode(m, x)
+    y = evalwgrad(m, x)

     #julia> x
     #[:, :, 1] =

@@ -172,7 +175,7 @@ end
   # check that μ, σ², and the output are the correct size for higher rank tensors
   let m = InstanceNorm(2), sizes = (5, 5, 3, 4, 2, 6),
       x = reshape(Float32.(collect(1:prod(sizes))), sizes)
-    y = trainmode(m, x)
+    y = evalwgrad(m, x)
     @test size(m.μ) == (sizes[end - 1], )
     @test size(m.σ²) == (sizes[end - 1], )
     @test size(y) == sizes

@@ -204,7 +207,7 @@ if VERSION >= v"1.1"
     @test m.β == [0, 0, 0, 0]  # initβ(32)
     @test m.γ == [1, 1, 1, 1]  # initγ(32)

-    y = trainmode(m, x)
+    y = evalwgrad(m, x)

     #julia> x
     #[:, :, 1] =

@@ -273,7 +276,7 @@ if VERSION >= v"1.1"
   # check that μ, σ², and the output are the correct size for higher rank tensors
   let m = GroupNorm(4,2), sizes = (5, 5, 3, 4, 4, 6),
       x = Float32.(reshape(collect(1:prod(sizes)), sizes))
-    y = trainmode(m, x)
+    y = evalwgrad(m, x)
     @test size(m.μ) == (m.G,1)
     @test size(m.σ²) == (m.G,1)
     @test size(y) == sizes