Added testmode! functionality back to normalization layers.

Kyle Daruwalla 2020-02-20 23:27:36 -06:00
parent 88b0c65d72
commit 7c12af065a
3 changed files with 79 additions and 26 deletions


@@ -11,7 +11,7 @@ export gradient
 export Chain, Dense, Maxout, RNN, LSTM, GRU, Conv, CrossCor, ConvTranspose, MaxPool, MeanPool,
        DepthwiseConv, Dropout, AlphaDropout, LayerNorm, BatchNorm, InstanceNorm, GroupNorm,
-       SkipConnection, params, fmap, cpu, gpu, f32, f64
+       SkipConnection, params, fmap, cpu, gpu, f32, f64, testmode!
 
 include("optimise/Optimise.jl")
 using .Optimise


@@ -2,6 +2,23 @@ istraining() = false
 @adjoint istraining() = true, _ -> nothing
 
+_isactive(m) = isnothing(m.active) ? istraining() : m.active
+
+# @adjoint _isactive(m) = _isactive(m), Δ -> nothing
+
+"""
+    testmode!(m, mode = :auto)
+
+Set a layer or model's test mode (see below).
+Using `:auto` mode will treat any gradient computation as training.
+
+Possible values include:
+- `false` for training
+- `true` for testing
+- `:auto` or `nothing` for Flux to detect the mode automatically
+"""
+testmode!(m, mode) = nothing
+testmode!(m::Chain, mode = :auto) = map(x -> testmode!(x, mode), m.layers)
+
 _dropout_shape(s, ::Colon) = size(s)
 _dropout_shape(s, dims) = tuple((i ∉ dims ? 1 : si for (i, si) ∈ enumerate(size(s)))...)
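As a quick reference for the interface added above, a minimal usage sketch (the model and layer sizes here are hypothetical, not part of this commit):

```
using Flux

m = Chain(Dense(10, 5), Dropout(0.5), BatchNorm(5))  # hypothetical model

testmode!(m, true)   # force test mode: Dropout becomes a no-op, BatchNorm uses its moving statistics
testmode!(m, false)  # force training behaviour, even outside a gradient computation
testmode!(m, :auto)  # default: fall back to istraining() to pick the mode per call
```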
@@ -22,18 +39,27 @@ A Dropout layer. For each input, either sets that input to `0` (with probability
 `p`) or scales it by `1/(1-p)`. The `dims` argument is to specified the unbroadcasted
 dimensions, i.e. `dims=1` does dropout along columns and `dims=2` along rows. This is
 used as a regularisation, i.e. it reduces overfitting during training. see also [`dropout`](@ref).
+
+Does nothing to the input once [`testmode!`](@ref) is true.
""" """
mutable struct Dropout{F,D} mutable struct Dropout{F,D}
p::F p::F
dims::D dims::D
active::Union{Bool, Nothing}
end end
function Dropout(p; dims = :) function Dropout(p; dims = :)
@assert 0 p 1 @assert 0 p 1
Dropout{typeof(p),typeof(dims)}(p, dims) Dropout{typeof(p),typeof(dims)}(p, dims, nothing)
end end
(a::Dropout)(x) = dropout(x, a.p; dims = a.dims) function (a::Dropout)(x)
_isactive(a) || return x
return dropout(x, a.p; dims = a.dims)
end
testmode!(m::Dropout, mode = :auto) =
(m.active = (isnothing(mode) || mode == :auto) ? nothing : !mode)
function Base.show(io::IO, d::Dropout) function Base.show(io::IO, d::Dropout)
print(io, "Dropout(", d.p) print(io, "Dropout(", d.p)
@@ -46,17 +72,20 @@ end
 A dropout layer. It is used in Self-Normalizing Neural Networks.
 (https://papers.nips.cc/paper/6698-self-normalizing-neural-networks.pdf)
 The AlphaDropout layer ensures that mean and variance of activations remains the same as before.
+
+Does nothing to the input once [`testmode!`](@ref) is true.
""" """
mutable struct AlphaDropout{F} mutable struct AlphaDropout{F}
p::F p::F
function AlphaDropout(p) active::Union{Bool, Nothing}
function AlphaDropout(p, active = nothing)
@assert 0 p 1 @assert 0 p 1
new{typeof(p)}(p) new{typeof(p)}(p, active)
end end
end end
function (a::AlphaDropout)(x) function (a::AlphaDropout)(x)
istraining() || return x _isactive(a) || return x
λ = eltype(x)(1.0507009873554804934193349852946) λ = eltype(x)(1.0507009873554804934193349852946)
α = eltype(x)(1.6732632423543772848170429916717) α = eltype(x)(1.6732632423543772848170429916717)
α1 = eltype(x)(-λ*α) α1 = eltype(x)(-λ*α)
@ -68,6 +97,9 @@ function (a::AlphaDropout)(x)
return x return x
end end
testmode!(m::AlphaDropout, mode = :auto) =
(m.active = (isnothing(mode) || mode == :auto) ? nothing : !mode)
""" """
LayerNorm(h::Integer) LayerNorm(h::Integer)
@@ -106,6 +138,8 @@ it's the usual channel dimension.)
 shifts them to have a new mean and variance (corresponding to the learnable,
 per-channel `bias` and `scale` parameters).
 
+Use [`testmode!`](@ref) during inference.
+
 See [Batch Normalization: Accelerating Deep Network Training by Reducing
 Internal Covariate Shift](https://arxiv.org/pdf/1502.03167.pdf).
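The inference workflow suggested by this docstring looks roughly like the following sketch (channel count and input size are made up for illustration):

```
using Flux

bn = BatchNorm(3)              # three channels
x = rand(Float32, 5, 5, 3, 1)  # WHCN input; channels sit in the (N-1)th dimension

testmode!(bn, true)            # use the stored moving mean/variance instead of batch statistics
y = bn(x)

testmode!(bn, :auto)           # restore the default, gradient-dependent behaviour
```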
@@ -127,12 +161,13 @@ mutable struct BatchNorm{F,V,W,N}
   σ²::W  # moving std
   ϵ::N
   momentum::N
+  active::Union{Bool, Nothing}
 end
 
 BatchNorm(chs::Integer, λ = identity;
           initβ = (i) -> zeros(Float32, i), initγ = (i) -> ones(Float32, i), ϵ = 1f-5, momentum = 0.1f0) =
   BatchNorm(λ, initβ(chs), initγ(chs),
-            zeros(chs), ones(chs), ϵ, momentum)
+            zeros(chs), ones(chs), ϵ, momentum, nothing)
 
 trainable(bn::BatchNorm) = (bn.β, bn.γ)
@@ -145,7 +180,7 @@ function (BN::BatchNorm)(x)
   m = div(prod(size(x)), channels)
   γ = reshape(BN.γ, affine_shape...)
   β = reshape(BN.β, affine_shape...)
-  if !istraining()
+  if !_isactive(BN)
     μ = reshape(BN.μ, affine_shape...)
     σ² = reshape(BN.σ², affine_shape...)
     ϵ = BN.ϵ
@@ -170,6 +205,9 @@ end
 
 @functor BatchNorm
 
+testmode!(m::BatchNorm, mode = :auto) =
+  (m.active = (isnothing(mode) || mode == :auto) ? nothing : !mode)
+
 function Base.show(io::IO, l::BatchNorm)
   print(io, "BatchNorm($(join(size(l.β), ", "))")
   (l.λ == identity) || print(io, ", λ = $(l.λ)")
@@ -193,6 +231,8 @@ it's the usual channel dimension.)
 shifts them to have a new mean and variance (corresponding to the learnable,
 per-channel `bias` and `scale` parameters).
 
+Use [`testmode!`](@ref) during inference.
+
 See [Instance Normalization: The Missing Ingredient for Fast Stylization](https://arxiv.org/abs/1607.08022).
 
 Example:
@@ -215,12 +255,13 @@ mutable struct InstanceNorm{F,V,W,N}
   σ²::W  # moving std
   ϵ::N
   momentum::N
+  active::Union{Bool, Nothing}
 end
 
 InstanceNorm(chs::Integer, λ = identity;
              initβ = (i) -> zeros(Float32, i), initγ = (i) -> ones(Float32, i), ϵ = 1f-5, momentum = 0.1f0) =
   InstanceNorm(λ, initβ(chs), initγ(chs),
-               zeros(chs), ones(chs), ϵ, momentum)
+               zeros(chs), ones(chs), ϵ, momentum, nothing)
 
 trainable(in::InstanceNorm) = (in.β, in.γ)
@@ -237,7 +278,7 @@ function (in::InstanceNorm)(x)
   m = div(prod(size(x)), c*bs)
   γ, β = expand_inst(in.γ, affine_shape), expand_inst(in.β, affine_shape)
-  if !istraining()
+  if !_isactive(in)
     μ = expand_inst(in.μ, affine_shape)
     σ² = expand_inst(in.σ², affine_shape)
     ϵ = in.ϵ
@@ -263,6 +304,9 @@ end
 
 @functor InstanceNorm
 
+testmode!(m::InstanceNorm, mode = :auto) =
+  (m.active = (isnothing(mode) || mode == :auto) ? nothing : !mode)
+
 function Base.show(io::IO, l::InstanceNorm)
   print(io, "InstanceNorm($(join(size(l.β), ", "))")
   (l.λ == identity) || print(io, ", λ = $(l.λ)")
@@ -283,6 +327,8 @@ For an array of N dimensions, the (N-1)th index is the channel dimension.
 ``G`` is the number of groups along which the statistics would be computed.
 The number of channels must be an integer multiple of the number of groups.
 
+Use [`testmode!`](@ref) during inference.
+
 Example:
 ```
 m = Chain(Conv((3,3), 1=>32, leakyrelu;pad = 1),
@@ -300,12 +346,13 @@ mutable struct GroupNorm{F,V,W,N,T}
   σ²::W  # moving std
   ϵ::N
   momentum::N
+  active::Union{Bool, Nothing}
 end
 
 GroupNorm(chs::Integer, G::Integer, λ = identity;
           initβ = (i) -> zeros(Float32, i), initγ = (i) -> ones(Float32, i), ϵ = 1f-5, momentum = 0.1f0) =
   GroupNorm(G, λ, initβ(chs), initγ(chs),
-            zeros(G,1), ones(G,1), ϵ, momentum)
+            zeros(G,1), ones(G,1), ϵ, momentum, nothing)
 
 trainable(gn::GroupNorm) = (gn.β, gn.γ)
@@ -329,7 +376,7 @@ function(gn::GroupNorm)(x)
   β = reshape(gn.β, affine_shape...)
   y = reshape(x,((size(x))[1:end-2]...,channels_per_group,groups,batches))
-  if !istraining()
+  if !_isactive(gn)
     og_shape = size(x)
     μ = reshape(gn.μ, μ_affine_shape...) # Shape : (1,1,...C/G,G,1)
     σ² = reshape(gn.σ², μ_affine_shape...) # Shape : (1,1,...C/G,G,1)
@@ -360,6 +407,9 @@ end
 
 @functor GroupNorm
 
+testmode!(m::GroupNorm, mode = :auto) =
+  (m.active = (isnothing(mode) || mode == :auto) ? nothing : !mode)
+
 function Base.show(io::IO, l::GroupNorm)
   print(io, "GroupNorm($(join(size(l.β), ", "))")
   (l.λ == identity) || print(io, ", λ = $(l.λ)")


@@ -1,30 +1,33 @@
 using Flux, Test, Statistics
 using Zygote: pullback
 
-trainmode(f, x...) = pullback(f, x...)[1]
-trainmode(f) = (x...) -> trainmode(f, x...)
+evalwgrad(f, x...) = pullback(f, x...)[1]
+trainmode(f) = (testmode!(f, false); f)
 
 @testset "Dropout" begin
   x = [1.,2.,3.]
   @test x == Dropout(0.1)(x)
-  @test x == trainmode(Dropout(0), x)
-  @test zero(x) == trainmode(Dropout(1), x)
+  @test x == evalwgrad(Dropout(0), x)
+  @test zero(x) == evalwgrad(Dropout(1), x)
 
   x = rand(100)
   m = Dropout(0.9)
-  y = trainmode(m, x)
+  y = evalwgrad(m, x)
   @test count(a->a==0, y) > 50
-  y = m(x)
+  testmode!(m, true)
+  y = evalwgrad(m, x) # should override istraining
   @test count(a->a==0, y) == 0
-  y = trainmode(m, x)
+  testmode!(m, false)
+  y = evalwgrad(m, x)
   @test count(a->a==0, y) > 50
 
   x = rand(Float32, 100)
   m = Chain(Dense(100,100),
             Dropout(0.9))
-  y = trainmode(m, x)
+  y = evalwgrad(m, x)
   @test count(a->a == 0, y) > 50
-  y = m(x)
+  testmode!(m, true)
+  y = evalwgrad(m, x) # should override istraining
   @test count(a->a == 0, y) == 0
 
   x = rand(100, 50)
@@ -49,7 +52,7 @@ end
     # initial m.σ is 1
     # initial m.μ is 0
-    y = trainmode(m, x)
+    y = evalwgrad(m, x)
     @test isapprox(y, [-1.22474 0 1.22474; -1.22474 0 1.22474], atol = 1.0e-5)
     # julia> x
     #  2×3 Array{Float64,2}:
@@ -117,7 +120,7 @@ end
     x = Float64.(x)
     @test m.β == [0, 0]  # initβ(2)
     @test m.γ == [1, 1]  # initγ(2)
-    y = trainmode(m, x)
+    y = evalwgrad(m, x)
 
     #julia> x
     #[:, :, 1] =
@@ -172,7 +175,7 @@ end
   # check that μ, σ², and the output are the correct size for higher rank tensors
   let m = InstanceNorm(2), sizes = (5, 5, 3, 4, 2, 6),
       x = reshape(Float32.(collect(1:prod(sizes))), sizes)
-    y = trainmode(m, x)
+    y = evalwgrad(m, x)
     @test size(m.μ) == (sizes[end - 1], )
     @test size(m.σ²) == (sizes[end - 1], )
     @test size(y) == sizes
@@ -204,7 +207,7 @@ if VERSION >= v"1.1"
     @test m.β == [0, 0, 0, 0]  # initβ(32)
     @test m.γ == [1, 1, 1, 1]  # initγ(32)
-    y = trainmode(m, x)
+    y = evalwgrad(m, x)
     #julia> x
     #[:, :, 1] =
@@ -273,7 +276,7 @@ if VERSION >= v"1.1"
   # check that μ, σ², and the output are the correct size for higher rank tensors
   let m = GroupNorm(4,2), sizes = (5, 5, 3, 4, 4, 6),
      x = Float32.(reshape(collect(1:prod(sizes)), sizes))
-    y = trainmode(m, x)
+    y = evalwgrad(m, x)
     @test size(m.μ) == (m.G,1)
     @test size(m.σ²) == (m.G,1)
     @test size(y) == sizes