diff --git a/src/Flux.jl b/src/Flux.jl
index 9969b323..5f9878f3 100644
--- a/src/Flux.jl
+++ b/src/Flux.jl
@@ -11,7 +11,7 @@ export gradient
 
 export Chain, Dense, Maxout, RNN, LSTM, GRU, Conv, CrossCor, ConvTranspose, MaxPool, MeanPool,
        DepthwiseConv, Dropout, AlphaDropout, LayerNorm, BatchNorm, InstanceNorm, GroupNorm,
-       SkipConnection, params, fmap, cpu, gpu, f32, f64
+       SkipConnection, params, fmap, cpu, gpu, f32, f64, testmode!
 
 include("optimise/Optimise.jl")
 using .Optimise
diff --git a/src/layers/normalise.jl b/src/layers/normalise.jl
index b421d3e7..ee6b6fdd 100644
--- a/src/layers/normalise.jl
+++ b/src/layers/normalise.jl
@@ -2,6 +2,23 @@ istraining() = false
 
 @adjoint istraining() = true, _ -> nothing
 
+_isactive(m) = isnothing(m.active) ? istraining() : m.active
+# @adjoint _isactive(m) = _isactive(m), Δ -> nothing
+
+"""
+    testmode!(m, mode = :auto)
+
+Set a layer or model's test mode (see below).
+Using `:auto` mode will treat any gradient computation as training.
+
+Possible values include:
+- `false` for training
+- `true` for testing
+- `:auto` or `nothing` for Flux to detect the mode automatically
+"""
+testmode!(m, mode) = nothing
+testmode!(m::Chain, mode = :auto) = map(x -> testmode!(x, mode), m.layers)
+
 _dropout_shape(s, ::Colon) = size(s)
 _dropout_shape(s, dims) = tuple((i ∉ dims ? 1 : si for (i, si) ∈ enumerate(size(s)))...)
 
@@ -22,18 +39,27 @@ A Dropout layer. For each input, either sets that input to `0` (with probability
 `p`) or scales it by `1/(1-p)`. The `dims` argument is to specified the unbroadcasted
 dimensions, i.e. `dims=1` does dropout along columns and `dims=2` along rows. This is
 used as a regularisation, i.e. it reduces overfitting during training. see also [`dropout`](@ref).
+
+Does nothing to the input once [`testmode!`](@ref) is true.
 """
 mutable struct Dropout{F,D}
   p::F
   dims::D
+  active::Union{Bool, Nothing}
 end
 
 function Dropout(p; dims = :)
   @assert 0 ≤ p ≤ 1
-  Dropout{typeof(p),typeof(dims)}(p, dims)
+  Dropout{typeof(p),typeof(dims)}(p, dims, nothing)
 end
 
-(a::Dropout)(x) = dropout(x, a.p; dims = a.dims)
+function (a::Dropout)(x)
+  _isactive(a) || return x
+  return dropout(x, a.p; dims = a.dims)
+end
+
+testmode!(m::Dropout, mode = :auto) =
+  (m.active = (isnothing(mode) || mode == :auto) ? nothing : !mode)
 
 function Base.show(io::IO, d::Dropout)
   print(io, "Dropout(", d.p)
@@ -46,17 +72,20 @@ end
 A dropout layer. It is used in Self-Normalizing Neural Networks.
 (https://papers.nips.cc/paper/6698-self-normalizing-neural-networks.pdf)
 The AlphaDropout layer ensures that mean and variance of activations remains the same as before.
+
+Does nothing to the input once [`testmode!`](@ref) is true.
 """
 mutable struct AlphaDropout{F}
   p::F
-  function AlphaDropout(p)
+  active::Union{Bool, Nothing}
+  function AlphaDropout(p, active = nothing)
     @assert 0 ≤ p ≤ 1
-    new{typeof(p)}(p)
+    new{typeof(p)}(p, active)
   end
 end
 
 function (a::AlphaDropout)(x)
-  istraining() || return x
+  _isactive(a) || return x
   λ = eltype(x)(1.0507009873554804934193349852946)
   α = eltype(x)(1.6732632423543772848170429916717)
   α1 = eltype(x)(-λ*α)
@@ -68,6 +97,9 @@ function (a::AlphaDropout)(x)
   return x
 end
 
+testmode!(m::AlphaDropout, mode = :auto) =
+  (m.active = (isnothing(mode) || mode == :auto) ? nothing : !mode)
+
 """
     LayerNorm(h::Integer)
 
@@ -106,6 +138,8 @@ it's the usual channel dimension.)
 shifts them to have a new mean and variance (corresponding to the learnable,
 per-channel `bias` and `scale` parameters).
 
+Use [`testmode!`](@ref) during inference.
+
 See [Batch Normalization: Accelerating Deep Network Training by Reducing
 Internal Covariate Shift](https://arxiv.org/pdf/1502.03167.pdf).
 
@@ -127,12 +161,13 @@ mutable struct BatchNorm{F,V,W,N}
   σ²::W  # moving std
   ϵ::N
   momentum::N
+  active::Union{Bool, Nothing}
 end
 
 BatchNorm(chs::Integer, λ = identity;
           initβ = (i) -> zeros(Float32, i), initγ = (i) -> ones(Float32, i), ϵ = 1f-5, momentum = 0.1f0) =
   BatchNorm(λ, initβ(chs), initγ(chs),
-            zeros(chs), ones(chs), ϵ, momentum)
+            zeros(chs), ones(chs), ϵ, momentum, nothing)
 
 trainable(bn::BatchNorm) = (bn.β, bn.γ)
 
@@ -145,7 +180,7 @@ function (BN::BatchNorm)(x)
   m = div(prod(size(x)), channels)
   γ = reshape(BN.γ, affine_shape...)
   β = reshape(BN.β, affine_shape...)
-  if !istraining()
+  if !_isactive(BN)
     μ = reshape(BN.μ, affine_shape...)
     σ² = reshape(BN.σ², affine_shape...)
     ϵ = BN.ϵ
@@ -170,6 +205,9 @@ end
 
 @functor BatchNorm
 
+testmode!(m::BatchNorm, mode = :auto) =
+  (m.active = (isnothing(mode) || mode == :auto) ? nothing : !mode)
+
 function Base.show(io::IO, l::BatchNorm)
   print(io, "BatchNorm($(join(size(l.β), ", "))")
   (l.λ == identity) || print(io, ", λ = $(l.λ)")
@@ -193,6 +231,8 @@ it's the usual channel dimension.)
 shifts them to have a new mean and variance (corresponding to the learnable,
 per-channel `bias` and `scale` parameters).
 
+Use [`testmode!`](@ref) during inference.
+
 See [Instance Normalization: The Missing Ingredient for Fast Stylization](https://arxiv.org/abs/1607.08022).
 
 Example:
@@ -215,12 +255,13 @@ mutable struct InstanceNorm{F,V,W,N}
   σ²::W  # moving std
   ϵ::N
   momentum::N
+  active::Union{Bool, Nothing}
 end
 
 InstanceNorm(chs::Integer, λ = identity;
              initβ = (i) -> zeros(Float32, i), initγ = (i) -> ones(Float32, i), ϵ = 1f-5, momentum = 0.1f0) =
   InstanceNorm(λ, initβ(chs), initγ(chs),
-               zeros(chs), ones(chs), ϵ, momentum)
+               zeros(chs), ones(chs), ϵ, momentum, nothing)
 
 trainable(in::InstanceNorm) = (in.β, in.γ)
 
@@ -237,7 +278,7 @@ function (in::InstanceNorm)(x)
   m = div(prod(size(x)), c*bs)
   γ, β = expand_inst(in.γ, affine_shape), expand_inst(in.β, affine_shape)
 
-  if !istraining()
+  if !_isactive(in)
     μ = expand_inst(in.μ, affine_shape)
     σ² = expand_inst(in.σ², affine_shape)
     ϵ = in.ϵ
@@ -263,6 +304,9 @@ end
 
 @functor InstanceNorm
 
+testmode!(m::InstanceNorm, mode = :auto) =
+  (m.active = (isnothing(mode) || mode == :auto) ? nothing : !mode)
+
 function Base.show(io::IO, l::InstanceNorm)
   print(io, "InstanceNorm($(join(size(l.β), ", "))")
   (l.λ == identity) || print(io, ", λ = $(l.λ)")
@@ -283,6 +327,8 @@ For an array of N dimensions, the (N-1)th index is the channel dimension.
 ``G`` is the number of groups along which the statistics would be computed.
 The number of channels must be an integer multiple of the number of groups.
 
+Use [`testmode!`](@ref) during inference.
+
 Example:
 ```
 m = Chain(Conv((3,3), 1=>32, leakyrelu;pad = 1),
@@ -300,12 +346,13 @@ mutable struct GroupNorm{F,V,W,N,T}
   σ²::W  # moving std
   ϵ::N
   momentum::N
+  active::Union{Bool, Nothing}
 end
 
 GroupNorm(chs::Integer, G::Integer, λ = identity;
           initβ = (i) -> zeros(Float32, i), initγ = (i) -> ones(Float32, i), ϵ = 1f-5, momentum = 0.1f0) =
   GroupNorm(G, λ, initβ(chs), initγ(chs),
-            zeros(G,1), ones(G,1), ϵ, momentum)
+            zeros(G,1), ones(G,1), ϵ, momentum, nothing)
 
 trainable(gn::GroupNorm) = (gn.β, gn.γ)
 
@@ -329,7 +376,7 @@ function(gn::GroupNorm)(x)
   β = reshape(gn.β, affine_shape...)
 
   y = reshape(x,((size(x))[1:end-2]...,channels_per_group,groups,batches))
-  if !istraining()
+  if !_isactive(gn)
     og_shape = size(x)
     μ = reshape(gn.μ, μ_affine_shape...) # Shape : (1,1,...C/G,G,1)
     σ² = reshape(gn.σ², μ_affine_shape...) # Shape : (1,1,...C/G,G,1)
@@ -360,6 +407,9 @@ end
 
 @functor GroupNorm
 
+testmode!(m::GroupNorm, mode = :auto) =
+  (m.active = (isnothing(mode) || mode == :auto) ? nothing : !mode)
+
 function Base.show(io::IO, l::GroupNorm)
   print(io, "GroupNorm($(join(size(l.β), ", "))")
   (l.λ == identity) || print(io, ", λ = $(l.λ)")
diff --git a/test/layers/normalisation.jl b/test/layers/normalisation.jl
index 4399a256..594fb586 100644
--- a/test/layers/normalisation.jl
+++ b/test/layers/normalisation.jl
@@ -1,30 +1,33 @@
 using Flux, Test, Statistics
 using Zygote: pullback
 
-trainmode(f, x...) = pullback(f, x...)[1]
-trainmode(f) = (x...) -> trainmode(f, x...)
+evalwgrad(f, x...) = pullback(f, x...)[1]
+trainmode(f) = (testmode!(f, false); f)
 
 @testset "Dropout" begin
   x = [1.,2.,3.]
   @test x == Dropout(0.1)(x)
-  @test x == trainmode(Dropout(0), x)
-  @test zero(x) == trainmode(Dropout(1), x)
+  @test x == evalwgrad(Dropout(0), x)
+  @test zero(x) == evalwgrad(Dropout(1), x)
 
   x = rand(100)
   m = Dropout(0.9)
-  y = trainmode(m, x)
+  y = evalwgrad(m, x)
   @test count(a->a==0, y) > 50
-  y = m(x)
+  testmode!(m, true)
+  y = evalwgrad(m, x) # should override istraining
   @test count(a->a==0, y) == 0
-  y = trainmode(m, x)
+  testmode!(m, false)
+  y = evalwgrad(m, x)
   @test count(a->a==0, y) > 50
 
   x = rand(Float32, 100)
   m = Chain(Dense(100,100),
            Dropout(0.9))
-  y = trainmode(m, x)
+  y = evalwgrad(m, x)
   @test count(a->a == 0, y) > 50
-  y = m(x)
+  testmode!(m, true)
+  y = evalwgrad(m, x) # should override istraining
   @test count(a->a == 0, y) == 0
 
   x = rand(100, 50)
@@ -49,7 +52,7 @@ end
     # initial m.σ is 1
    # initial m.μ is 0
 
-    y = trainmode(m, x)
+    y = evalwgrad(m, x)
     @test isapprox(y, [-1.22474 0 1.22474; -1.22474 0 1.22474], atol = 1.0e-5)
     # julia> x
     #  2×3 Array{Float64,2}:
@@ -117,7 +120,7 @@ end
     x = Float64.(x)
     @test m.β == [0, 0]  # initβ(2)
     @test m.γ == [1, 1]  # initγ(2)
-    y = trainmode(m, x)
+    y = evalwgrad(m, x)
 
     #julia> x
     #[:, :, 1] =
@@ -172,7 +175,7 @@ end
   # check that μ, σ², and the output are the correct size for higher rank tensors
   let m = InstanceNorm(2), sizes = (5, 5, 3, 4, 2, 6),
       x = reshape(Float32.(collect(1:prod(sizes))), sizes)
-    y = trainmode(m, x)
+    y = evalwgrad(m, x)
     @test size(m.μ) == (sizes[end - 1], )
     @test size(m.σ²) == (sizes[end - 1], )
     @test size(y) == sizes
@@ -204,7 +207,7 @@ if VERSION >= v"1.1"
     @test m.β == [0, 0, 0, 0]  # initβ(32)
     @test m.γ == [1, 1, 1, 1]  # initγ(32)
 
-    y = trainmode(m, x)
+    y = evalwgrad(m, x)
 
     #julia> x
     #[:, :, 1] =
@@ -273,7 +276,7 @@ if VERSION >= v"1.1"
   # check that μ, σ², and the output are the correct size for higher rank tensors
   let m = GroupNorm(4,2), sizes = (5, 5, 3, 4, 4, 6),
       x = Float32.(reshape(collect(1:prod(sizes)), sizes))
-    y = trainmode(m, x)
+    y = evalwgrad(m, x)
     @test size(m.μ) == (m.G,1)
     @test size(m.σ²) == (m.G,1)
     @test size(y) == sizes
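
For reference, here is a minimal usage sketch of the `testmode!` API added by this diff. The model, layer sizes, and input below are made up purely for illustration and are not part of the patch; only the `testmode!` calls and their semantics come from the changes above.

```julia
using Flux

# Illustrative model only: one stochastic layer (Dropout) and one layer that
# tracks running statistics (BatchNorm). Sizes and data are arbitrary.
m = Chain(Dense(10, 10, relu), Dropout(0.5), BatchNorm(10))
x = rand(Float32, 10, 4)

m(x)                # default (:auto): layers are only active inside a gradient call

testmode!(m, true)  # force test mode: Dropout is a no-op, BatchNorm uses its moving stats
m(x)

testmode!(m, false) # force training behaviour, even outside of gradient computation
m(x)

testmode!(m, :auto) # back to automatic detection via istraining()
```

Calling `testmode!` on a `Chain` recurses into its layers; layers without an `active` field (such as `Dense`) fall through to the no-op `testmode!(m, mode)` fallback.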