diff --git a/docs/src/models/layers.md b/docs/src/models/layers.md index 41e98f32..e2b6b1eb 100644 --- a/docs/src/models/layers.md +++ b/docs/src/models/layers.md @@ -54,6 +54,15 @@ LayerNorm GroupNorm ``` +### Testmode + +Many normalisation layers behave differently under training and inference (testing). By default, Flux will automatically determine when a layer evaluation is part of training or inference. Still, depending on your use case, it may be helpful to manually specify when these layers should be treated as being trained or not. For this, Flux provides `testmode!`. When called on a model (e.g. a layer or chain of layers), this function will place the model into the mode specified. + +```@docs +testmode! +trainmode! +``` + ## Cost Functions ```@docs Flux.mse diff --git a/src/Flux.jl b/src/Flux.jl index c99e41a1..78670e65 100644 --- a/src/Flux.jl +++ b/src/Flux.jl @@ -12,7 +12,7 @@ export gradient export Chain, Dense, Maxout, RNN, LSTM, GRU, Conv, CrossCor, ConvTranspose, MaxPool, MeanPool, DepthwiseConv, Dropout, AlphaDropout, LayerNorm, BatchNorm, InstanceNorm, GroupNorm, - SkipConnection, params, fmap, cpu, gpu, f32, f64 + SkipConnection, params, fmap, cpu, gpu, f32, f64, testmode!, trainmode! include("optimise/Optimise.jl") using .Optimise diff --git a/src/functor.jl b/src/functor.jl index a36b5765..0d7c55f1 100644 --- a/src/functor.jl +++ b/src/functor.jl @@ -39,6 +39,38 @@ end trainable(m) = functor(m)[1] +""" + testmode!(m, mode = true) + +Set a layer or model's test mode (see below). +Using `:auto` mode will treat any gradient computation as training. + +_Note_: if you manually set a model into test mode, you need to manually place +it back into train mode during training phase. + +Possible values include: +- `false` for training +- `true` for testing +- `:auto` or `nothing` for Flux to detect the mode automatically +""" +testmode!(m, mode = true) = m + +""" + trainmode!(m, mode = true) + +Set a layer of model's train mode (see below). +Symmetric to [`testmode!`](@ref) (i.e. `trainmode!(m, mode) == testmode!(m, !mode)). + +_Note_: if you manually set a model into train mode, you need to manually place +it into test mode during testing phase. + +Possible values include: +- `true` for training +- `false` for testing +- `:auto` or `nothing` for Flux to detect the mode automatically +""" +trainmode!(m, mode = true) = mode isa Bool ? testmode!(m, !mode) : testmode!(m, mode) + params!(p::Params, x::AbstractArray{<:Number}, seen = IdSet()) = push!(p, x) function params!(p::Params, x, seen = IdSet()) diff --git a/src/layers/basic.jl b/src/layers/basic.jl index 6f056429..24fab689 100644 --- a/src/layers/basic.jl +++ b/src/layers/basic.jl @@ -33,6 +33,8 @@ applychain(fs::Tuple, x) = applychain(tail(fs), first(fs)(x)) Base.getindex(c::Chain, i::AbstractArray) = Chain(c.layers[i]...) +testmode!(m::Chain, mode = true) = (map(x -> testmode!(x, mode), m.layers); m) + function Base.show(io::IO, c::Chain) print(io, "Chain(") join(io, c.layers, ", ") diff --git a/src/layers/normalise.jl b/src/layers/normalise.jl index 2268fdc0..fc781f70 100644 --- a/src/layers/normalise.jl +++ b/src/layers/normalise.jl @@ -2,6 +2,8 @@ istraining() = false @adjoint istraining() = true, _ -> nothing +_isactive(m) = isnothing(m.active) ? istraining() : m.active + _dropout_shape(s, ::Colon) = size(s) _dropout_shape(s, dims) = tuple((i ∉ dims ? 1 : si for (i, si) ∈ enumerate(size(s)))...) @@ -29,18 +31,27 @@ end Dropout(p, dims = :) A Dropout layer. In the forward pass, applies the [`dropout`](@ref) function on the input. + +Does nothing to the input once [`testmode!`](@ref) is false. """ mutable struct Dropout{F,D} p::F dims::D + active::Union{Bool, Nothing} end function Dropout(p; dims = :) @assert 0 ≤ p ≤ 1 - Dropout{typeof(p),typeof(dims)}(p, dims) + Dropout{typeof(p),typeof(dims)}(p, dims, nothing) end -(a::Dropout)(x) = dropout(x, a.p; dims = a.dims) +function (a::Dropout)(x) + _isactive(a) || return x + return dropout(x, a.p; dims = a.dims) +end + +testmode!(m::Dropout, mode = true) = + (m.active = (isnothing(mode) || mode == :auto) ? nothing : !mode; m) function Base.show(io::IO, d::Dropout) print(io, "Dropout(", d.p) @@ -54,17 +65,20 @@ end A dropout layer. It is used in Self-Normalizing Neural Networks. (https://papers.nips.cc/paper/6698-self-normalizing-neural-networks.pdf) The AlphaDropout layer ensures that mean and variance of activations remains the same as before. + +Does nothing to the input once [`testmode!`](@ref) is false. """ mutable struct AlphaDropout{F} p::F - function AlphaDropout(p) + active::Union{Bool, Nothing} + function AlphaDropout(p, active = nothing) @assert 0 ≤ p ≤ 1 - new{typeof(p)}(p) + new{typeof(p)}(p, active) end end function (a::AlphaDropout)(x) - istraining() || return x + _isactive(a) || return x λ = eltype(x)(1.0507009873554804934193349852946) α = eltype(x)(1.6732632423543772848170429916717) α1 = eltype(x)(-λ*α) @@ -76,6 +90,9 @@ function (a::AlphaDropout)(x) return x end +testmode!(m::AlphaDropout, mode = true) = + (m.active = (isnothing(mode) || mode == :auto) ? nothing : !mode; m) + """ LayerNorm(h::Integer) @@ -114,6 +131,8 @@ it's the usual channel dimension.) shifts them to have a new mean and variance (corresponding to the learnable, per-channel `bias` and `scale` parameters). +Use [`testmode!`](@ref) during inference. + See [Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift](https://arxiv.org/pdf/1502.03167.pdf). @@ -135,12 +154,13 @@ mutable struct BatchNorm{F,V,W,N} σ²::W # moving std ϵ::N momentum::N + active::Union{Bool, Nothing} end BatchNorm(chs::Integer, λ = identity; initβ = (i) -> zeros(Float32, i), initγ = (i) -> ones(Float32, i), ϵ = 1f-5, momentum = 0.1f0) = BatchNorm(λ, initβ(chs), initγ(chs), - zeros(chs), ones(chs), ϵ, momentum) + zeros(chs), ones(chs), ϵ, momentum, nothing) trainable(bn::BatchNorm) = (bn.β, bn.γ) @@ -153,7 +173,7 @@ function (BN::BatchNorm)(x) m = div(prod(size(x)), channels) γ = reshape(BN.γ, affine_shape...) β = reshape(BN.β, affine_shape...) - if !istraining() + if !_isactive(BN) μ = reshape(BN.μ, affine_shape...) σ² = reshape(BN.σ², affine_shape...) ϵ = BN.ϵ @@ -178,6 +198,9 @@ end @functor BatchNorm +testmode!(m::BatchNorm, mode = true) = + (m.active = (isnothing(mode) || mode == :auto) ? nothing : !mode; m) + function Base.show(io::IO, l::BatchNorm) print(io, "BatchNorm($(join(size(l.β), ", "))") (l.λ == identity) || print(io, ", λ = $(l.λ)") @@ -201,6 +224,8 @@ it's the usual channel dimension.) shifts them to have a new mean and variance (corresponding to the learnable, per-channel `bias` and `scale` parameters). +Use [`testmode!`](@ref) during inference. + See [Instance Normalization: The Missing Ingredient for Fast Stylization](https://arxiv.org/abs/1607.08022). Example: @@ -223,12 +248,13 @@ mutable struct InstanceNorm{F,V,W,N} σ²::W # moving std ϵ::N momentum::N + active::Union{Bool, Nothing} end InstanceNorm(chs::Integer, λ = identity; initβ = (i) -> zeros(Float32, i), initγ = (i) -> ones(Float32, i), ϵ = 1f-5, momentum = 0.1f0) = InstanceNorm(λ, initβ(chs), initγ(chs), - zeros(chs), ones(chs), ϵ, momentum) + zeros(chs), ones(chs), ϵ, momentum, nothing) trainable(in::InstanceNorm) = (in.β, in.γ) @@ -245,7 +271,7 @@ function (in::InstanceNorm)(x) m = div(prod(size(x)), c*bs) γ, β = expand_inst(in.γ, affine_shape), expand_inst(in.β, affine_shape) - if !istraining() + if !_isactive(in) μ = expand_inst(in.μ, affine_shape) σ² = expand_inst(in.σ², affine_shape) ϵ = in.ϵ @@ -271,6 +297,9 @@ end @functor InstanceNorm +testmode!(m::InstanceNorm, mode = true) = + (m.active = (isnothing(mode) || mode == :auto) ? nothing : !mode; m) + function Base.show(io::IO, l::InstanceNorm) print(io, "InstanceNorm($(join(size(l.β), ", "))") (l.λ == identity) || print(io, ", λ = $(l.λ)") @@ -291,6 +320,8 @@ For an array of N dimensions, the (N-1)th index is the channel dimension. ``G`` is the number of groups along which the statistics would be computed. The number of channels must be an integer multiple of the number of groups. +Use [`testmode!`](@ref) during inference. + Example: ``` m = Chain(Conv((3,3), 1=>32, leakyrelu;pad = 1), @@ -308,12 +339,13 @@ mutable struct GroupNorm{F,V,W,N,T} σ²::W # moving std ϵ::N momentum::N + active::Union{Bool, Nothing} end GroupNorm(chs::Integer, G::Integer, λ = identity; initβ = (i) -> zeros(Float32, i), initγ = (i) -> ones(Float32, i), ϵ = 1f-5, momentum = 0.1f0) = GroupNorm(G, λ, initβ(chs), initγ(chs), - zeros(G,1), ones(G,1), ϵ, momentum) + zeros(G,1), ones(G,1), ϵ, momentum, nothing) trainable(gn::GroupNorm) = (gn.β, gn.γ) @@ -337,7 +369,7 @@ function(gn::GroupNorm)(x) β = reshape(gn.β, affine_shape...) y = reshape(x,((size(x))[1:end-2]...,channels_per_group,groups,batches)) - if !istraining() + if !_isactive(gn) og_shape = size(x) μ = reshape(gn.μ, μ_affine_shape...) # Shape : (1,1,...C/G,G,1) σ² = reshape(gn.σ², μ_affine_shape...) # Shape : (1,1,...C/G,G,1) @@ -368,6 +400,9 @@ end @functor GroupNorm +testmode!(m::GroupNorm, mode = true) = + (m.active = (isnothing(mode) || mode == :auto) ? nothing : !mode; m) + function Base.show(io::IO, l::GroupNorm) print(io, "GroupNorm($(join(size(l.β), ", "))") (l.λ == identity) || print(io, ", λ = $(l.λ)") diff --git a/test/layers/normalisation.jl b/test/layers/normalisation.jl index 4399a256..ed2879b0 100644 --- a/test/layers/normalisation.jl +++ b/test/layers/normalisation.jl @@ -1,30 +1,32 @@ using Flux, Test, Statistics using Zygote: pullback -trainmode(f, x...) = pullback(f, x...)[1] -trainmode(f) = (x...) -> trainmode(f, x...) +evalwgrad(f, x...) = pullback(f, x...)[1] @testset "Dropout" begin x = [1.,2.,3.] @test x == Dropout(0.1)(x) - @test x == trainmode(Dropout(0), x) - @test zero(x) == trainmode(Dropout(1), x) + @test x == evalwgrad(Dropout(0), x) + @test zero(x) == evalwgrad(Dropout(1), x) x = rand(100) m = Dropout(0.9) - y = trainmode(m, x) + y = evalwgrad(m, x) @test count(a->a==0, y) > 50 - y = m(x) + testmode!(m, true) + y = evalwgrad(m, x) # should override istraining @test count(a->a==0, y) == 0 - y = trainmode(m, x) + testmode!(m, false) + y = evalwgrad(m, x) @test count(a->a==0, y) > 50 x = rand(Float32, 100) m = Chain(Dense(100,100), Dropout(0.9)) - y = trainmode(m, x) + y = evalwgrad(m, x) @test count(a->a == 0, y) > 50 - y = m(x) + testmode!(m, true) + y = evalwgrad(m, x) # should override istraining @test count(a->a == 0, y) == 0 x = rand(100, 50) @@ -49,7 +51,7 @@ end # initial m.σ is 1 # initial m.μ is 0 - y = trainmode(m, x) + y = evalwgrad(m, x) @test isapprox(y, [-1.22474 0 1.22474; -1.22474 0 1.22474], atol = 1.0e-5) # julia> x # 2×3 Array{Float64,2}: @@ -82,19 +84,19 @@ end @test isapprox(y, sigmoid.((x .- m.μ) ./ sqrt.(m.σ² .+ m.ϵ)), atol = 1.0e-7) end - let m = trainmode(BatchNorm(2)), x = reshape(Float32.(1:6), 3, 2, 1) + let m = trainmode!(BatchNorm(2)), x = reshape(Float32.(1:6), 3, 2, 1) y = reshape(permutedims(x, [2, 1, 3]), 2, :) y = permutedims(reshape(m(y), 2, 3, 1), [2, 1, 3]) @test m(x) == y end - let m = trainmode(BatchNorm(2)), x = reshape(Float32.(1:12), 2, 3, 2, 1) + let m = trainmode!(BatchNorm(2)), x = reshape(Float32.(1:12), 2, 3, 2, 1) y = reshape(permutedims(x, [3, 1, 2, 4]), 2, :) y = permutedims(reshape(m(y), 2, 2, 3, 1), [2, 3, 1, 4]) @test m(x) == y end - let m = trainmode(BatchNorm(2)), x = reshape(Float32.(1:24), 2, 2, 3, 2, 1) + let m = trainmode!(BatchNorm(2)), x = reshape(Float32.(1:24), 2, 2, 3, 2, 1) y = reshape(permutedims(x, [4, 1, 2, 3, 5]), 2, :) y = permutedims(reshape(m(y), 2, 2, 2, 3, 1), [2, 3, 4, 1, 5]) @test m(x) == y @@ -117,7 +119,7 @@ end x = Float64.(x) @test m.β == [0, 0] # initβ(2) @test m.γ == [1, 1] # initγ(2) - y = trainmode(m, x) + y = evalwgrad(m, x) #julia> x #[:, :, 1] = @@ -162,7 +164,7 @@ end @test isapprox(y, sigmoid.((x .- expand_inst(m.μ, affine_shape)) ./ sqrt.(expand_inst(m.σ², affine_shape) .+ m.ϵ)), atol = 1.0e-7) end - let m = trainmode(InstanceNorm(2)), sizes = (2, 4, 1, 2, 3), + let m = trainmode!(InstanceNorm(2)), sizes = (2, 4, 1, 2, 3), x = Float32.(reshape(collect(1:prod(sizes)), sizes)) y = reshape(permutedims(x, [3, 1, 2, 4, 5]), :, 2, 3) y = reshape(m(y), sizes...) @@ -172,14 +174,14 @@ end # check that μ, σ², and the output are the correct size for higher rank tensors let m = InstanceNorm(2), sizes = (5, 5, 3, 4, 2, 6), x = reshape(Float32.(collect(1:prod(sizes))), sizes) - y = trainmode(m, x) + y = evalwgrad(m, x) @test size(m.μ) == (sizes[end - 1], ) @test size(m.σ²) == (sizes[end - 1], ) @test size(y) == sizes end # show that instance norm is equal to batch norm when channel and batch dims are squashed - let m_inorm = trainmode(InstanceNorm(2)), m_bnorm = trainmode(BatchNorm(12)), sizes = (5, 5, 3, 4, 2, 6), + let m_inorm = trainmode!(InstanceNorm(2)), m_bnorm = trainmode!(BatchNorm(12)), sizes = (5, 5, 3, 4, 2, 6), x = reshape(Float32.(collect(1:prod(sizes))), sizes) @test m_inorm(x) == reshape(m_bnorm(reshape(x, (sizes[1:end - 2]..., :, 1))), sizes) end @@ -204,7 +206,7 @@ if VERSION >= v"1.1" @test m.β == [0, 0, 0, 0] # initβ(32) @test m.γ == [1, 1, 1, 1] # initγ(32) - y = trainmode(m, x) + y = evalwgrad(m, x) #julia> x #[:, :, 1] = @@ -263,7 +265,7 @@ if VERSION >= v"1.1" @test isapprox(y, out, atol = 1.0e-7) end - let m = trainmode(GroupNorm(2,2)), sizes = (2, 4, 1, 2, 3), + let m = trainmode!(GroupNorm(2,2)), sizes = (2, 4, 1, 2, 3), x = Float32.(reshape(collect(1:prod(sizes)), sizes)) y = reshape(permutedims(x, [3, 1, 2, 4, 5]), :, 2, 3) y = reshape(m(y), sizes...) @@ -273,20 +275,20 @@ if VERSION >= v"1.1" # check that μ, σ², and the output are the correct size for higher rank tensors let m = GroupNorm(4,2), sizes = (5, 5, 3, 4, 4, 6), x = Float32.(reshape(collect(1:prod(sizes)), sizes)) - y = trainmode(m, x) + y = evalwgrad(m, x) @test size(m.μ) == (m.G,1) @test size(m.σ²) == (m.G,1) @test size(y) == sizes end # show that group norm is the same as instance norm when the group size is the same as the number of channels - let IN = trainmode(InstanceNorm(4)), GN = trainmode(GroupNorm(4,4)), sizes = (2,2,3,4,5), + let IN = trainmode!(InstanceNorm(4)), GN = trainmode!(GroupNorm(4,4)), sizes = (2,2,3,4,5), x = Float32.(reshape(collect(1:prod(sizes)), sizes)) @test IN(x) ≈ GN(x) end # show that group norm is the same as batch norm for a group of size 1 and batch of size 1 - let BN = trainmode(BatchNorm(4)), GN = trainmode(GroupNorm(4,4)), sizes = (2,2,3,4,1), + let BN = trainmode!(BatchNorm(4)), GN = trainmode!(GroupNorm(4,4)), sizes = (2,2,3,4,1), x = Float32.(reshape(collect(1:prod(sizes)), sizes)) @test BN(x) ≈ GN(x) end