Merge #1044
1044: Add testmode! back for normalization layers r=CarloLucibello a=darsnack Fixed #909 I added `testmode!(m, mode)` back to Flux as per v0.9. Now the `mode` can be `false`, `true`, or `:auto`/`nothing` with the default being `:auto` for newly constructed layers. In `:auto` mode, the `istraining()` functions added in v0.10 are used to determine whether we are evaluating within an AD trace or not. Also plan on adding a doc section in an additional commit. Co-authored-by: Kyle Daruwalla <daruwalla@wisc.edu>
This commit is contained in:
commit
069d228693
|
@ -54,6 +54,15 @@ LayerNorm
|
|||
GroupNorm
|
||||
```
|
||||
|
||||
### Testmode
|
||||
|
||||
Many normalisation layers behave differently under training and inference (testing). By default, Flux will automatically determine when a layer evaluation is part of training or inference. Still, depending on your use case, it may be helpful to manually specify when these layers should be treated as being trained or not. For this, Flux provides `testmode!`. When called on a model (e.g. a layer or chain of layers), this function will place the model into the mode specified.
|
||||
|
||||
```@docs
|
||||
testmode!
|
||||
trainmode!
|
||||
```
|
||||
|
||||
## Cost Functions
|
||||
```@docs
|
||||
Flux.mse
|
||||
|
|
|
@ -12,7 +12,7 @@ export gradient
|
|||
|
||||
export Chain, Dense, Maxout, RNN, LSTM, GRU, Conv, CrossCor, ConvTranspose, MaxPool, MeanPool,
|
||||
DepthwiseConv, Dropout, AlphaDropout, LayerNorm, BatchNorm, InstanceNorm, GroupNorm,
|
||||
SkipConnection, params, fmap, cpu, gpu, f32, f64
|
||||
SkipConnection, params, fmap, cpu, gpu, f32, f64, testmode!, trainmode!
|
||||
|
||||
include("optimise/Optimise.jl")
|
||||
using .Optimise
|
||||
|
|
|
@ -39,6 +39,38 @@ end
|
|||
|
||||
trainable(m) = functor(m)[1]
|
||||
|
||||
"""
|
||||
testmode!(m, mode = true)
|
||||
|
||||
Set a layer or model's test mode (see below).
|
||||
Using `:auto` mode will treat any gradient computation as training.
|
||||
|
||||
_Note_: if you manually set a model into test mode, you need to manually place
|
||||
it back into train mode during training phase.
|
||||
|
||||
Possible values include:
|
||||
- `false` for training
|
||||
- `true` for testing
|
||||
- `:auto` or `nothing` for Flux to detect the mode automatically
|
||||
"""
|
||||
testmode!(m, mode = true) = m
|
||||
|
||||
"""
|
||||
trainmode!(m, mode = true)
|
||||
|
||||
Set a layer of model's train mode (see below).
|
||||
Symmetric to [`testmode!`](@ref) (i.e. `trainmode!(m, mode) == testmode!(m, !mode)).
|
||||
|
||||
_Note_: if you manually set a model into train mode, you need to manually place
|
||||
it into test mode during testing phase.
|
||||
|
||||
Possible values include:
|
||||
- `true` for training
|
||||
- `false` for testing
|
||||
- `:auto` or `nothing` for Flux to detect the mode automatically
|
||||
"""
|
||||
trainmode!(m, mode = true) = mode isa Bool ? testmode!(m, !mode) : testmode!(m, mode)
|
||||
|
||||
params!(p::Params, x::AbstractArray{<:Number}, seen = IdSet()) = push!(p, x)
|
||||
|
||||
function params!(p::Params, x, seen = IdSet())
|
||||
|
|
|
@ -33,6 +33,8 @@ applychain(fs::Tuple, x) = applychain(tail(fs), first(fs)(x))
|
|||
|
||||
Base.getindex(c::Chain, i::AbstractArray) = Chain(c.layers[i]...)
|
||||
|
||||
testmode!(m::Chain, mode = true) = (map(x -> testmode!(x, mode), m.layers); m)
|
||||
|
||||
function Base.show(io::IO, c::Chain)
|
||||
print(io, "Chain(")
|
||||
join(io, c.layers, ", ")
|
||||
|
|
|
@ -2,6 +2,8 @@ istraining() = false
|
|||
|
||||
@adjoint istraining() = true, _ -> nothing
|
||||
|
||||
_isactive(m) = isnothing(m.active) ? istraining() : m.active
|
||||
|
||||
_dropout_shape(s, ::Colon) = size(s)
|
||||
_dropout_shape(s, dims) = tuple((i ∉ dims ? 1 : si for (i, si) ∈ enumerate(size(s)))...)
|
||||
|
||||
|
@ -29,18 +31,27 @@ end
|
|||
Dropout(p, dims = :)
|
||||
|
||||
A Dropout layer. In the forward pass, applies the [`dropout`](@ref) function on the input.
|
||||
|
||||
Does nothing to the input once [`testmode!`](@ref) is false.
|
||||
"""
|
||||
mutable struct Dropout{F,D}
|
||||
p::F
|
||||
dims::D
|
||||
active::Union{Bool, Nothing}
|
||||
end
|
||||
|
||||
function Dropout(p; dims = :)
|
||||
@assert 0 ≤ p ≤ 1
|
||||
Dropout{typeof(p),typeof(dims)}(p, dims)
|
||||
Dropout{typeof(p),typeof(dims)}(p, dims, nothing)
|
||||
end
|
||||
|
||||
(a::Dropout)(x) = dropout(x, a.p; dims = a.dims)
|
||||
function (a::Dropout)(x)
|
||||
_isactive(a) || return x
|
||||
return dropout(x, a.p; dims = a.dims)
|
||||
end
|
||||
|
||||
testmode!(m::Dropout, mode = true) =
|
||||
(m.active = (isnothing(mode) || mode == :auto) ? nothing : !mode; m)
|
||||
|
||||
function Base.show(io::IO, d::Dropout)
|
||||
print(io, "Dropout(", d.p)
|
||||
|
@ -54,17 +65,20 @@ end
|
|||
A dropout layer. It is used in Self-Normalizing Neural Networks.
|
||||
(https://papers.nips.cc/paper/6698-self-normalizing-neural-networks.pdf)
|
||||
The AlphaDropout layer ensures that mean and variance of activations remains the same as before.
|
||||
|
||||
Does nothing to the input once [`testmode!`](@ref) is false.
|
||||
"""
|
||||
mutable struct AlphaDropout{F}
|
||||
p::F
|
||||
function AlphaDropout(p)
|
||||
active::Union{Bool, Nothing}
|
||||
function AlphaDropout(p, active = nothing)
|
||||
@assert 0 ≤ p ≤ 1
|
||||
new{typeof(p)}(p)
|
||||
new{typeof(p)}(p, active)
|
||||
end
|
||||
end
|
||||
|
||||
function (a::AlphaDropout)(x)
|
||||
istraining() || return x
|
||||
_isactive(a) || return x
|
||||
λ = eltype(x)(1.0507009873554804934193349852946)
|
||||
α = eltype(x)(1.6732632423543772848170429916717)
|
||||
α1 = eltype(x)(-λ*α)
|
||||
|
@ -76,6 +90,9 @@ function (a::AlphaDropout)(x)
|
|||
return x
|
||||
end
|
||||
|
||||
testmode!(m::AlphaDropout, mode = true) =
|
||||
(m.active = (isnothing(mode) || mode == :auto) ? nothing : !mode; m)
|
||||
|
||||
"""
|
||||
LayerNorm(h::Integer)
|
||||
|
||||
|
@ -114,6 +131,8 @@ it's the usual channel dimension.)
|
|||
shifts them to have a new mean and variance (corresponding to the learnable,
|
||||
per-channel `bias` and `scale` parameters).
|
||||
|
||||
Use [`testmode!`](@ref) during inference.
|
||||
|
||||
See [Batch Normalization: Accelerating Deep Network Training by Reducing
|
||||
Internal Covariate Shift](https://arxiv.org/pdf/1502.03167.pdf).
|
||||
|
||||
|
@ -135,12 +154,13 @@ mutable struct BatchNorm{F,V,W,N}
|
|||
σ²::W # moving std
|
||||
ϵ::N
|
||||
momentum::N
|
||||
active::Union{Bool, Nothing}
|
||||
end
|
||||
|
||||
BatchNorm(chs::Integer, λ = identity;
|
||||
initβ = (i) -> zeros(Float32, i), initγ = (i) -> ones(Float32, i), ϵ = 1f-5, momentum = 0.1f0) =
|
||||
BatchNorm(λ, initβ(chs), initγ(chs),
|
||||
zeros(chs), ones(chs), ϵ, momentum)
|
||||
zeros(chs), ones(chs), ϵ, momentum, nothing)
|
||||
|
||||
trainable(bn::BatchNorm) = (bn.β, bn.γ)
|
||||
|
||||
|
@ -153,7 +173,7 @@ function (BN::BatchNorm)(x)
|
|||
m = div(prod(size(x)), channels)
|
||||
γ = reshape(BN.γ, affine_shape...)
|
||||
β = reshape(BN.β, affine_shape...)
|
||||
if !istraining()
|
||||
if !_isactive(BN)
|
||||
μ = reshape(BN.μ, affine_shape...)
|
||||
σ² = reshape(BN.σ², affine_shape...)
|
||||
ϵ = BN.ϵ
|
||||
|
@ -178,6 +198,9 @@ end
|
|||
|
||||
@functor BatchNorm
|
||||
|
||||
testmode!(m::BatchNorm, mode = true) =
|
||||
(m.active = (isnothing(mode) || mode == :auto) ? nothing : !mode; m)
|
||||
|
||||
function Base.show(io::IO, l::BatchNorm)
|
||||
print(io, "BatchNorm($(join(size(l.β), ", "))")
|
||||
(l.λ == identity) || print(io, ", λ = $(l.λ)")
|
||||
|
@ -201,6 +224,8 @@ it's the usual channel dimension.)
|
|||
shifts them to have a new mean and variance (corresponding to the learnable,
|
||||
per-channel `bias` and `scale` parameters).
|
||||
|
||||
Use [`testmode!`](@ref) during inference.
|
||||
|
||||
See [Instance Normalization: The Missing Ingredient for Fast Stylization](https://arxiv.org/abs/1607.08022).
|
||||
|
||||
Example:
|
||||
|
@ -223,12 +248,13 @@ mutable struct InstanceNorm{F,V,W,N}
|
|||
σ²::W # moving std
|
||||
ϵ::N
|
||||
momentum::N
|
||||
active::Union{Bool, Nothing}
|
||||
end
|
||||
|
||||
InstanceNorm(chs::Integer, λ = identity;
|
||||
initβ = (i) -> zeros(Float32, i), initγ = (i) -> ones(Float32, i), ϵ = 1f-5, momentum = 0.1f0) =
|
||||
InstanceNorm(λ, initβ(chs), initγ(chs),
|
||||
zeros(chs), ones(chs), ϵ, momentum)
|
||||
zeros(chs), ones(chs), ϵ, momentum, nothing)
|
||||
|
||||
trainable(in::InstanceNorm) = (in.β, in.γ)
|
||||
|
||||
|
@ -245,7 +271,7 @@ function (in::InstanceNorm)(x)
|
|||
m = div(prod(size(x)), c*bs)
|
||||
γ, β = expand_inst(in.γ, affine_shape), expand_inst(in.β, affine_shape)
|
||||
|
||||
if !istraining()
|
||||
if !_isactive(in)
|
||||
μ = expand_inst(in.μ, affine_shape)
|
||||
σ² = expand_inst(in.σ², affine_shape)
|
||||
ϵ = in.ϵ
|
||||
|
@ -271,6 +297,9 @@ end
|
|||
|
||||
@functor InstanceNorm
|
||||
|
||||
testmode!(m::InstanceNorm, mode = true) =
|
||||
(m.active = (isnothing(mode) || mode == :auto) ? nothing : !mode; m)
|
||||
|
||||
function Base.show(io::IO, l::InstanceNorm)
|
||||
print(io, "InstanceNorm($(join(size(l.β), ", "))")
|
||||
(l.λ == identity) || print(io, ", λ = $(l.λ)")
|
||||
|
@ -291,6 +320,8 @@ For an array of N dimensions, the (N-1)th index is the channel dimension.
|
|||
``G`` is the number of groups along which the statistics would be computed.
|
||||
The number of channels must be an integer multiple of the number of groups.
|
||||
|
||||
Use [`testmode!`](@ref) during inference.
|
||||
|
||||
Example:
|
||||
```
|
||||
m = Chain(Conv((3,3), 1=>32, leakyrelu;pad = 1),
|
||||
|
@ -308,12 +339,13 @@ mutable struct GroupNorm{F,V,W,N,T}
|
|||
σ²::W # moving std
|
||||
ϵ::N
|
||||
momentum::N
|
||||
active::Union{Bool, Nothing}
|
||||
end
|
||||
|
||||
GroupNorm(chs::Integer, G::Integer, λ = identity;
|
||||
initβ = (i) -> zeros(Float32, i), initγ = (i) -> ones(Float32, i), ϵ = 1f-5, momentum = 0.1f0) =
|
||||
GroupNorm(G, λ, initβ(chs), initγ(chs),
|
||||
zeros(G,1), ones(G,1), ϵ, momentum)
|
||||
zeros(G,1), ones(G,1), ϵ, momentum, nothing)
|
||||
|
||||
trainable(gn::GroupNorm) = (gn.β, gn.γ)
|
||||
|
||||
|
@ -337,7 +369,7 @@ function(gn::GroupNorm)(x)
|
|||
β = reshape(gn.β, affine_shape...)
|
||||
|
||||
y = reshape(x,((size(x))[1:end-2]...,channels_per_group,groups,batches))
|
||||
if !istraining()
|
||||
if !_isactive(gn)
|
||||
og_shape = size(x)
|
||||
μ = reshape(gn.μ, μ_affine_shape...) # Shape : (1,1,...C/G,G,1)
|
||||
σ² = reshape(gn.σ², μ_affine_shape...) # Shape : (1,1,...C/G,G,1)
|
||||
|
@ -368,6 +400,9 @@ end
|
|||
|
||||
@functor GroupNorm
|
||||
|
||||
testmode!(m::GroupNorm, mode = true) =
|
||||
(m.active = (isnothing(mode) || mode == :auto) ? nothing : !mode; m)
|
||||
|
||||
function Base.show(io::IO, l::GroupNorm)
|
||||
print(io, "GroupNorm($(join(size(l.β), ", "))")
|
||||
(l.λ == identity) || print(io, ", λ = $(l.λ)")
|
||||
|
|
|
@ -1,30 +1,32 @@
|
|||
using Flux, Test, Statistics
|
||||
using Zygote: pullback
|
||||
|
||||
trainmode(f, x...) = pullback(f, x...)[1]
|
||||
trainmode(f) = (x...) -> trainmode(f, x...)
|
||||
evalwgrad(f, x...) = pullback(f, x...)[1]
|
||||
|
||||
@testset "Dropout" begin
|
||||
x = [1.,2.,3.]
|
||||
@test x == Dropout(0.1)(x)
|
||||
@test x == trainmode(Dropout(0), x)
|
||||
@test zero(x) == trainmode(Dropout(1), x)
|
||||
@test x == evalwgrad(Dropout(0), x)
|
||||
@test zero(x) == evalwgrad(Dropout(1), x)
|
||||
|
||||
x = rand(100)
|
||||
m = Dropout(0.9)
|
||||
y = trainmode(m, x)
|
||||
y = evalwgrad(m, x)
|
||||
@test count(a->a==0, y) > 50
|
||||
y = m(x)
|
||||
testmode!(m, true)
|
||||
y = evalwgrad(m, x) # should override istraining
|
||||
@test count(a->a==0, y) == 0
|
||||
y = trainmode(m, x)
|
||||
testmode!(m, false)
|
||||
y = evalwgrad(m, x)
|
||||
@test count(a->a==0, y) > 50
|
||||
|
||||
x = rand(Float32, 100)
|
||||
m = Chain(Dense(100,100),
|
||||
Dropout(0.9))
|
||||
y = trainmode(m, x)
|
||||
y = evalwgrad(m, x)
|
||||
@test count(a->a == 0, y) > 50
|
||||
y = m(x)
|
||||
testmode!(m, true)
|
||||
y = evalwgrad(m, x) # should override istraining
|
||||
@test count(a->a == 0, y) == 0
|
||||
|
||||
x = rand(100, 50)
|
||||
|
@ -49,7 +51,7 @@ end
|
|||
# initial m.σ is 1
|
||||
# initial m.μ is 0
|
||||
|
||||
y = trainmode(m, x)
|
||||
y = evalwgrad(m, x)
|
||||
@test isapprox(y, [-1.22474 0 1.22474; -1.22474 0 1.22474], atol = 1.0e-5)
|
||||
# julia> x
|
||||
# 2×3 Array{Float64,2}:
|
||||
|
@ -82,19 +84,19 @@ end
|
|||
@test isapprox(y, sigmoid.((x .- m.μ) ./ sqrt.(m.σ² .+ m.ϵ)), atol = 1.0e-7)
|
||||
end
|
||||
|
||||
let m = trainmode(BatchNorm(2)), x = reshape(Float32.(1:6), 3, 2, 1)
|
||||
let m = trainmode!(BatchNorm(2)), x = reshape(Float32.(1:6), 3, 2, 1)
|
||||
y = reshape(permutedims(x, [2, 1, 3]), 2, :)
|
||||
y = permutedims(reshape(m(y), 2, 3, 1), [2, 1, 3])
|
||||
@test m(x) == y
|
||||
end
|
||||
|
||||
let m = trainmode(BatchNorm(2)), x = reshape(Float32.(1:12), 2, 3, 2, 1)
|
||||
let m = trainmode!(BatchNorm(2)), x = reshape(Float32.(1:12), 2, 3, 2, 1)
|
||||
y = reshape(permutedims(x, [3, 1, 2, 4]), 2, :)
|
||||
y = permutedims(reshape(m(y), 2, 2, 3, 1), [2, 3, 1, 4])
|
||||
@test m(x) == y
|
||||
end
|
||||
|
||||
let m = trainmode(BatchNorm(2)), x = reshape(Float32.(1:24), 2, 2, 3, 2, 1)
|
||||
let m = trainmode!(BatchNorm(2)), x = reshape(Float32.(1:24), 2, 2, 3, 2, 1)
|
||||
y = reshape(permutedims(x, [4, 1, 2, 3, 5]), 2, :)
|
||||
y = permutedims(reshape(m(y), 2, 2, 2, 3, 1), [2, 3, 4, 1, 5])
|
||||
@test m(x) == y
|
||||
|
@ -117,7 +119,7 @@ end
|
|||
x = Float64.(x)
|
||||
@test m.β == [0, 0] # initβ(2)
|
||||
@test m.γ == [1, 1] # initγ(2)
|
||||
y = trainmode(m, x)
|
||||
y = evalwgrad(m, x)
|
||||
|
||||
#julia> x
|
||||
#[:, :, 1] =
|
||||
|
@ -162,7 +164,7 @@ end
|
|||
@test isapprox(y, sigmoid.((x .- expand_inst(m.μ, affine_shape)) ./ sqrt.(expand_inst(m.σ², affine_shape) .+ m.ϵ)), atol = 1.0e-7)
|
||||
end
|
||||
|
||||
let m = trainmode(InstanceNorm(2)), sizes = (2, 4, 1, 2, 3),
|
||||
let m = trainmode!(InstanceNorm(2)), sizes = (2, 4, 1, 2, 3),
|
||||
x = Float32.(reshape(collect(1:prod(sizes)), sizes))
|
||||
y = reshape(permutedims(x, [3, 1, 2, 4, 5]), :, 2, 3)
|
||||
y = reshape(m(y), sizes...)
|
||||
|
@ -172,14 +174,14 @@ end
|
|||
# check that μ, σ², and the output are the correct size for higher rank tensors
|
||||
let m = InstanceNorm(2), sizes = (5, 5, 3, 4, 2, 6),
|
||||
x = reshape(Float32.(collect(1:prod(sizes))), sizes)
|
||||
y = trainmode(m, x)
|
||||
y = evalwgrad(m, x)
|
||||
@test size(m.μ) == (sizes[end - 1], )
|
||||
@test size(m.σ²) == (sizes[end - 1], )
|
||||
@test size(y) == sizes
|
||||
end
|
||||
|
||||
# show that instance norm is equal to batch norm when channel and batch dims are squashed
|
||||
let m_inorm = trainmode(InstanceNorm(2)), m_bnorm = trainmode(BatchNorm(12)), sizes = (5, 5, 3, 4, 2, 6),
|
||||
let m_inorm = trainmode!(InstanceNorm(2)), m_bnorm = trainmode!(BatchNorm(12)), sizes = (5, 5, 3, 4, 2, 6),
|
||||
x = reshape(Float32.(collect(1:prod(sizes))), sizes)
|
||||
@test m_inorm(x) == reshape(m_bnorm(reshape(x, (sizes[1:end - 2]..., :, 1))), sizes)
|
||||
end
|
||||
|
@ -204,7 +206,7 @@ if VERSION >= v"1.1"
|
|||
@test m.β == [0, 0, 0, 0] # initβ(32)
|
||||
@test m.γ == [1, 1, 1, 1] # initγ(32)
|
||||
|
||||
y = trainmode(m, x)
|
||||
y = evalwgrad(m, x)
|
||||
|
||||
#julia> x
|
||||
#[:, :, 1] =
|
||||
|
@ -263,7 +265,7 @@ if VERSION >= v"1.1"
|
|||
@test isapprox(y, out, atol = 1.0e-7)
|
||||
end
|
||||
|
||||
let m = trainmode(GroupNorm(2,2)), sizes = (2, 4, 1, 2, 3),
|
||||
let m = trainmode!(GroupNorm(2,2)), sizes = (2, 4, 1, 2, 3),
|
||||
x = Float32.(reshape(collect(1:prod(sizes)), sizes))
|
||||
y = reshape(permutedims(x, [3, 1, 2, 4, 5]), :, 2, 3)
|
||||
y = reshape(m(y), sizes...)
|
||||
|
@ -273,20 +275,20 @@ if VERSION >= v"1.1"
|
|||
# check that μ, σ², and the output are the correct size for higher rank tensors
|
||||
let m = GroupNorm(4,2), sizes = (5, 5, 3, 4, 4, 6),
|
||||
x = Float32.(reshape(collect(1:prod(sizes)), sizes))
|
||||
y = trainmode(m, x)
|
||||
y = evalwgrad(m, x)
|
||||
@test size(m.μ) == (m.G,1)
|
||||
@test size(m.σ²) == (m.G,1)
|
||||
@test size(y) == sizes
|
||||
end
|
||||
|
||||
# show that group norm is the same as instance norm when the group size is the same as the number of channels
|
||||
let IN = trainmode(InstanceNorm(4)), GN = trainmode(GroupNorm(4,4)), sizes = (2,2,3,4,5),
|
||||
let IN = trainmode!(InstanceNorm(4)), GN = trainmode!(GroupNorm(4,4)), sizes = (2,2,3,4,5),
|
||||
x = Float32.(reshape(collect(1:prod(sizes)), sizes))
|
||||
@test IN(x) ≈ GN(x)
|
||||
end
|
||||
|
||||
# show that group norm is the same as batch norm for a group of size 1 and batch of size 1
|
||||
let BN = trainmode(BatchNorm(4)), GN = trainmode(GroupNorm(4,4)), sizes = (2,2,3,4,1),
|
||||
let BN = trainmode!(BatchNorm(4)), GN = trainmode!(GroupNorm(4,4)), sizes = (2,2,3,4,1),
|
||||
x = Float32.(reshape(collect(1:prod(sizes)), sizes))
|
||||
@test BN(x) ≈ GN(x)
|
||||
end
|
||||
|
|
Loading…
Reference in New Issue