From 7c12af065a2d8fb20359321e34f3c0731ae5559f Mon Sep 17 00:00:00 2001 From: Kyle Daruwalla Date: Thu, 20 Feb 2020 23:27:36 -0600 Subject: [PATCH 01/10] Added testmode! functionality back to normalization layers. --- src/Flux.jl | 2 +- src/layers/normalise.jl | 72 ++++++++++++++++++++++++++++++------ test/layers/normalisation.jl | 31 +++++++++------- 3 files changed, 79 insertions(+), 26 deletions(-) diff --git a/src/Flux.jl b/src/Flux.jl index 9969b323..5f9878f3 100644 --- a/src/Flux.jl +++ b/src/Flux.jl @@ -11,7 +11,7 @@ export gradient export Chain, Dense, Maxout, RNN, LSTM, GRU, Conv, CrossCor, ConvTranspose, MaxPool, MeanPool, DepthwiseConv, Dropout, AlphaDropout, LayerNorm, BatchNorm, InstanceNorm, GroupNorm, - SkipConnection, params, fmap, cpu, gpu, f32, f64 + SkipConnection, params, fmap, cpu, gpu, f32, f64, testmode! include("optimise/Optimise.jl") using .Optimise diff --git a/src/layers/normalise.jl b/src/layers/normalise.jl index b421d3e7..ee6b6fdd 100644 --- a/src/layers/normalise.jl +++ b/src/layers/normalise.jl @@ -2,6 +2,23 @@ istraining() = false @adjoint istraining() = true, _ -> nothing +_isactive(m) = isnothing(m.active) ? istraining() : m.active +# @adjoint _isactive(m) = _isactive(m), Δ -> nothing + +""" + testmode!(m, mode = :auto) + +Set a layer or model's test mode (see below). +Using `:auto` mode will treat any gradient computation as training. + +Possible values include: +- `false` for training +- `true` for testing +- `:auto` or `nothing` for Flux to detect the mode automatically +""" +testmode!(m, mode) = nothing +testmode!(m::Chain, mode = :auto) = map(x -> testmode!(x, mode), m.layers) + _dropout_shape(s, ::Colon) = size(s) _dropout_shape(s, dims) = tuple((i ∉ dims ? 1 : si for (i, si) ∈ enumerate(size(s)))...) @@ -22,18 +39,27 @@ A Dropout layer. For each input, either sets that input to `0` (with probability `p`) or scales it by `1/(1-p)`. The `dims` argument is to specified the unbroadcasted dimensions, i.e. `dims=1` does dropout along columns and `dims=2` along rows. This is used as a regularisation, i.e. it reduces overfitting during training. see also [`dropout`](@ref). + +Does nothing to the input once [`testmode!`](@ref) is false. """ mutable struct Dropout{F,D} p::F dims::D + active::Union{Bool, Nothing} end function Dropout(p; dims = :) @assert 0 ≤ p ≤ 1 - Dropout{typeof(p),typeof(dims)}(p, dims) + Dropout{typeof(p),typeof(dims)}(p, dims, nothing) end -(a::Dropout)(x) = dropout(x, a.p; dims = a.dims) +function (a::Dropout)(x) + _isactive(a) || return x + return dropout(x, a.p; dims = a.dims) +end + +testmode!(m::Dropout, mode = :auto) = + (m.active = (isnothing(mode) || mode == :auto) ? nothing : !mode) function Base.show(io::IO, d::Dropout) print(io, "Dropout(", d.p) @@ -46,17 +72,20 @@ end A dropout layer. It is used in Self-Normalizing Neural Networks. (https://papers.nips.cc/paper/6698-self-normalizing-neural-networks.pdf) The AlphaDropout layer ensures that mean and variance of activations remains the same as before. + +Does nothing to the input once [`testmode!`](@ref) is false. """ mutable struct AlphaDropout{F} p::F - function AlphaDropout(p) + active::Union{Bool, Nothing} + function AlphaDropout(p, active = nothing) @assert 0 ≤ p ≤ 1 - new{typeof(p)}(p) + new{typeof(p)}(p, active) end end function (a::AlphaDropout)(x) - istraining() || return x + _isactive(a) || return x λ = eltype(x)(1.0507009873554804934193349852946) α = eltype(x)(1.6732632423543772848170429916717) α1 = eltype(x)(-λ*α) @@ -68,6 +97,9 @@ function (a::AlphaDropout)(x) return x end +testmode!(m::AlphaDropout, mode = :auto) = + (m.active = (isnothing(mode) || mode == :auto) ? nothing : !mode) + """ LayerNorm(h::Integer) @@ -106,6 +138,8 @@ it's the usual channel dimension.) shifts them to have a new mean and variance (corresponding to the learnable, per-channel `bias` and `scale` parameters). +Use [`testmode!`](@ref) during inference. + See [Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift](https://arxiv.org/pdf/1502.03167.pdf). @@ -127,12 +161,13 @@ mutable struct BatchNorm{F,V,W,N} σ²::W # moving std ϵ::N momentum::N + active::Union{Bool, Nothing} end BatchNorm(chs::Integer, λ = identity; initβ = (i) -> zeros(Float32, i), initγ = (i) -> ones(Float32, i), ϵ = 1f-5, momentum = 0.1f0) = BatchNorm(λ, initβ(chs), initγ(chs), - zeros(chs), ones(chs), ϵ, momentum) + zeros(chs), ones(chs), ϵ, momentum, nothing) trainable(bn::BatchNorm) = (bn.β, bn.γ) @@ -145,7 +180,7 @@ function (BN::BatchNorm)(x) m = div(prod(size(x)), channels) γ = reshape(BN.γ, affine_shape...) β = reshape(BN.β, affine_shape...) - if !istraining() + if !_isactive(BN) μ = reshape(BN.μ, affine_shape...) σ² = reshape(BN.σ², affine_shape...) ϵ = BN.ϵ @@ -170,6 +205,9 @@ end @functor BatchNorm +testmode!(m::BatchNorm, mode = :auto) = + (m.active = (isnothing(mode) || mode == :auto) ? nothing : !mode) + function Base.show(io::IO, l::BatchNorm) print(io, "BatchNorm($(join(size(l.β), ", "))") (l.λ == identity) || print(io, ", λ = $(l.λ)") @@ -193,6 +231,8 @@ it's the usual channel dimension.) shifts them to have a new mean and variance (corresponding to the learnable, per-channel `bias` and `scale` parameters). +Use [`testmode!`](@ref) during inference. + See [Instance Normalization: The Missing Ingredient for Fast Stylization](https://arxiv.org/abs/1607.08022). Example: @@ -215,12 +255,13 @@ mutable struct InstanceNorm{F,V,W,N} σ²::W # moving std ϵ::N momentum::N + active::Union{Bool, Nothing} end InstanceNorm(chs::Integer, λ = identity; initβ = (i) -> zeros(Float32, i), initγ = (i) -> ones(Float32, i), ϵ = 1f-5, momentum = 0.1f0) = InstanceNorm(λ, initβ(chs), initγ(chs), - zeros(chs), ones(chs), ϵ, momentum) + zeros(chs), ones(chs), ϵ, momentum, nothing) trainable(in::InstanceNorm) = (in.β, in.γ) @@ -237,7 +278,7 @@ function (in::InstanceNorm)(x) m = div(prod(size(x)), c*bs) γ, β = expand_inst(in.γ, affine_shape), expand_inst(in.β, affine_shape) - if !istraining() + if !_isactive(in) μ = expand_inst(in.μ, affine_shape) σ² = expand_inst(in.σ², affine_shape) ϵ = in.ϵ @@ -263,6 +304,9 @@ end @functor InstanceNorm +testmode!(m::InstanceNorm, mode = :auto) = + (m.active = (isnothing(mode) || mode == :auto) ? nothing : !mode) + function Base.show(io::IO, l::InstanceNorm) print(io, "InstanceNorm($(join(size(l.β), ", "))") (l.λ == identity) || print(io, ", λ = $(l.λ)") @@ -283,6 +327,8 @@ For an array of N dimensions, the (N-1)th index is the channel dimension. ``G`` is the number of groups along which the statistics would be computed. The number of channels must be an integer multiple of the number of groups. +Use [`testmode!`](@ref) during inference. + Example: ``` m = Chain(Conv((3,3), 1=>32, leakyrelu;pad = 1), @@ -300,12 +346,13 @@ mutable struct GroupNorm{F,V,W,N,T} σ²::W # moving std ϵ::N momentum::N + active::Union{Bool, Nothing} end GroupNorm(chs::Integer, G::Integer, λ = identity; initβ = (i) -> zeros(Float32, i), initγ = (i) -> ones(Float32, i), ϵ = 1f-5, momentum = 0.1f0) = GroupNorm(G, λ, initβ(chs), initγ(chs), - zeros(G,1), ones(G,1), ϵ, momentum) + zeros(G,1), ones(G,1), ϵ, momentum, nothing) trainable(gn::GroupNorm) = (gn.β, gn.γ) @@ -329,7 +376,7 @@ function(gn::GroupNorm)(x) β = reshape(gn.β, affine_shape...) y = reshape(x,((size(x))[1:end-2]...,channels_per_group,groups,batches)) - if !istraining() + if !_isactive(gn) og_shape = size(x) μ = reshape(gn.μ, μ_affine_shape...) # Shape : (1,1,...C/G,G,1) σ² = reshape(gn.σ², μ_affine_shape...) # Shape : (1,1,...C/G,G,1) @@ -360,6 +407,9 @@ end @functor GroupNorm +testmode!(m::GroupNorm, mode = :auto) = + (m.active = (isnothing(mode) || mode == :auto) ? nothing : !mode) + function Base.show(io::IO, l::GroupNorm) print(io, "GroupNorm($(join(size(l.β), ", "))") (l.λ == identity) || print(io, ", λ = $(l.λ)") diff --git a/test/layers/normalisation.jl b/test/layers/normalisation.jl index 4399a256..594fb586 100644 --- a/test/layers/normalisation.jl +++ b/test/layers/normalisation.jl @@ -1,30 +1,33 @@ using Flux, Test, Statistics using Zygote: pullback -trainmode(f, x...) = pullback(f, x...)[1] -trainmode(f) = (x...) -> trainmode(f, x...) +evalwgrad(f, x...) = pullback(f, x...)[1] +trainmode(f) = (testmode!(f, false); f) @testset "Dropout" begin x = [1.,2.,3.] @test x == Dropout(0.1)(x) - @test x == trainmode(Dropout(0), x) - @test zero(x) == trainmode(Dropout(1), x) + @test x == evalwgrad(Dropout(0), x) + @test zero(x) == evalwgrad(Dropout(1), x) x = rand(100) m = Dropout(0.9) - y = trainmode(m, x) + y = evalwgrad(m, x) @test count(a->a==0, y) > 50 - y = m(x) + testmode!(m, true) + y = evalwgrad(m, x) # should override istraining @test count(a->a==0, y) == 0 - y = trainmode(m, x) + testmode!(m, false) + y = evalwgrad(m, x) @test count(a->a==0, y) > 50 x = rand(Float32, 100) m = Chain(Dense(100,100), Dropout(0.9)) - y = trainmode(m, x) + y = evalwgrad(m, x) @test count(a->a == 0, y) > 50 - y = m(x) + testmode!(m, true) + y = evalwgrad(m, x) # should override istraining @test count(a->a == 0, y) == 0 x = rand(100, 50) @@ -49,7 +52,7 @@ end # initial m.σ is 1 # initial m.μ is 0 - y = trainmode(m, x) + y = evalwgrad(m, x) @test isapprox(y, [-1.22474 0 1.22474; -1.22474 0 1.22474], atol = 1.0e-5) # julia> x # 2×3 Array{Float64,2}: @@ -117,7 +120,7 @@ end x = Float64.(x) @test m.β == [0, 0] # initβ(2) @test m.γ == [1, 1] # initγ(2) - y = trainmode(m, x) + y = evalwgrad(m, x) #julia> x #[:, :, 1] = @@ -172,7 +175,7 @@ end # check that μ, σ², and the output are the correct size for higher rank tensors let m = InstanceNorm(2), sizes = (5, 5, 3, 4, 2, 6), x = reshape(Float32.(collect(1:prod(sizes))), sizes) - y = trainmode(m, x) + y = evalwgrad(m, x) @test size(m.μ) == (sizes[end - 1], ) @test size(m.σ²) == (sizes[end - 1], ) @test size(y) == sizes @@ -204,7 +207,7 @@ if VERSION >= v"1.1" @test m.β == [0, 0, 0, 0] # initβ(32) @test m.γ == [1, 1, 1, 1] # initγ(32) - y = trainmode(m, x) + y = evalwgrad(m, x) #julia> x #[:, :, 1] = @@ -273,7 +276,7 @@ if VERSION >= v"1.1" # check that μ, σ², and the output are the correct size for higher rank tensors let m = GroupNorm(4,2), sizes = (5, 5, 3, 4, 4, 6), x = Float32.(reshape(collect(1:prod(sizes)), sizes)) - y = trainmode(m, x) + y = evalwgrad(m, x) @test size(m.μ) == (m.G,1) @test size(m.σ²) == (m.G,1) @test size(y) == sizes From 924b8f49ec9a438d35159e4e8ad5fbd75f0654ba Mon Sep 17 00:00:00 2001 From: Kyle Daruwalla Date: Fri, 21 Feb 2020 15:10:28 -0600 Subject: [PATCH 02/10] Updated to place function definitions in the appropriate places. --- src/functor.jl | 13 +++++++++++++ src/layers/basic.jl | 2 ++ src/layers/normalise.jl | 25 +++++-------------------- 3 files changed, 20 insertions(+), 20 deletions(-) diff --git a/src/functor.jl b/src/functor.jl index a36b5765..4edfbd98 100644 --- a/src/functor.jl +++ b/src/functor.jl @@ -39,6 +39,19 @@ end trainable(m) = functor(m)[1] +""" + testmode!(m, mode = true) + +Set a layer or model's test mode (see below). +Using `:auto` mode will treat any gradient computation as training. + +Possible values include: +- `false` for training +- `true` for testing +- `:auto` or `nothing` for Flux to detect the mode automatically +""" +testmode!(m, mode) = nothing + params!(p::Params, x::AbstractArray{<:Number}, seen = IdSet()) = push!(p, x) function params!(p::Params, x, seen = IdSet()) diff --git a/src/layers/basic.jl b/src/layers/basic.jl index 2a465208..6788f761 100644 --- a/src/layers/basic.jl +++ b/src/layers/basic.jl @@ -33,6 +33,8 @@ applychain(fs::Tuple, x) = applychain(tail(fs), first(fs)(x)) Base.getindex(c::Chain, i::AbstractArray) = Chain(c.layers[i]...) +testmode!(m::Chain, mode = true) = map(x -> testmode!(x, mode), m.layers) + function Base.show(io::IO, c::Chain) print(io, "Chain(") join(io, c.layers, ", ") diff --git a/src/layers/normalise.jl b/src/layers/normalise.jl index ee6b6fdd..7b438bc2 100644 --- a/src/layers/normalise.jl +++ b/src/layers/normalise.jl @@ -3,21 +3,6 @@ istraining() = false @adjoint istraining() = true, _ -> nothing _isactive(m) = isnothing(m.active) ? istraining() : m.active -# @adjoint _isactive(m) = _isactive(m), Δ -> nothing - -""" - testmode!(m, mode = :auto) - -Set a layer or model's test mode (see below). -Using `:auto` mode will treat any gradient computation as training. - -Possible values include: -- `false` for training -- `true` for testing -- `:auto` or `nothing` for Flux to detect the mode automatically -""" -testmode!(m, mode) = nothing -testmode!(m::Chain, mode = :auto) = map(x -> testmode!(x, mode), m.layers) _dropout_shape(s, ::Colon) = size(s) _dropout_shape(s, dims) = tuple((i ∉ dims ? 1 : si for (i, si) ∈ enumerate(size(s)))...) @@ -58,7 +43,7 @@ function (a::Dropout)(x) return dropout(x, a.p; dims = a.dims) end -testmode!(m::Dropout, mode = :auto) = +testmode!(m::Dropout, mode = true) = (m.active = (isnothing(mode) || mode == :auto) ? nothing : !mode) function Base.show(io::IO, d::Dropout) @@ -97,7 +82,7 @@ function (a::AlphaDropout)(x) return x end -testmode!(m::AlphaDropout, mode = :auto) = +testmode!(m::AlphaDropout, mode = true) = (m.active = (isnothing(mode) || mode == :auto) ? nothing : !mode) """ @@ -205,7 +190,7 @@ end @functor BatchNorm -testmode!(m::BatchNorm, mode = :auto) = +testmode!(m::BatchNorm, mode = true) = (m.active = (isnothing(mode) || mode == :auto) ? nothing : !mode) function Base.show(io::IO, l::BatchNorm) @@ -304,7 +289,7 @@ end @functor InstanceNorm -testmode!(m::InstanceNorm, mode = :auto) = +testmode!(m::InstanceNorm, mode = true) = (m.active = (isnothing(mode) || mode == :auto) ? nothing : !mode) function Base.show(io::IO, l::InstanceNorm) @@ -407,7 +392,7 @@ end @functor GroupNorm -testmode!(m::GroupNorm, mode = :auto) = +testmode!(m::GroupNorm, mode = true) = (m.active = (isnothing(mode) || mode == :auto) ? nothing : !mode) function Base.show(io::IO, l::GroupNorm) From ba5259a269f93b0dcf65dfca43b29b219bf81415 Mon Sep 17 00:00:00 2001 From: Kyle Daruwalla Date: Tue, 25 Feb 2020 13:53:49 -0600 Subject: [PATCH 03/10] Added docs on testmode! --- docs/src/models/layers.md | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/docs/src/models/layers.md b/docs/src/models/layers.md index 5f2ab3ce..763fbf8c 100644 --- a/docs/src/models/layers.md +++ b/docs/src/models/layers.md @@ -66,6 +66,14 @@ LayerNorm GroupNorm ``` +### Testmode + +Many normalisation layers behave differently under training and inference (testing). By default, Flux will automatically determine when a layer evaluation is part of training or inference. Still, depending on your use case, it may be helpful to manually specify when these layers should be treated as being trained or not. For this, Flux provides `testmode!`. When called on a model (e.g. a layer or chain of layers), this function will place the model into the mode specified. + +```@docs +testmode! +``` + ## Cost Functions ```@docs mse From 5cbd2cecf29cf58a4e4bd97e637515c299a522d8 Mon Sep 17 00:00:00 2001 From: Kyle Daruwalla Date: Sat, 29 Feb 2020 16:09:59 -0600 Subject: [PATCH 04/10] Changed testmode! to return model --- src/functor.jl | 2 +- src/layers/basic.jl | 2 +- src/layers/normalise.jl | 10 +++++----- test/layers/normalisation.jl | 16 ++++++++-------- 4 files changed, 15 insertions(+), 15 deletions(-) diff --git a/src/functor.jl b/src/functor.jl index 4edfbd98..ee384b98 100644 --- a/src/functor.jl +++ b/src/functor.jl @@ -50,7 +50,7 @@ Possible values include: - `true` for testing - `:auto` or `nothing` for Flux to detect the mode automatically """ -testmode!(m, mode) = nothing +testmode!(m, mode = true) = m params!(p::Params, x::AbstractArray{<:Number}, seen = IdSet()) = push!(p, x) diff --git a/src/layers/basic.jl b/src/layers/basic.jl index 6788f761..10d1f07b 100644 --- a/src/layers/basic.jl +++ b/src/layers/basic.jl @@ -33,7 +33,7 @@ applychain(fs::Tuple, x) = applychain(tail(fs), first(fs)(x)) Base.getindex(c::Chain, i::AbstractArray) = Chain(c.layers[i]...) -testmode!(m::Chain, mode = true) = map(x -> testmode!(x, mode), m.layers) +testmode!(m::Chain, mode = true) = (map(x -> testmode!(x, mode), m.layers); m) function Base.show(io::IO, c::Chain) print(io, "Chain(") diff --git a/src/layers/normalise.jl b/src/layers/normalise.jl index 7b438bc2..36c6d2bd 100644 --- a/src/layers/normalise.jl +++ b/src/layers/normalise.jl @@ -44,7 +44,7 @@ function (a::Dropout)(x) end testmode!(m::Dropout, mode = true) = - (m.active = (isnothing(mode) || mode == :auto) ? nothing : !mode) + (m.active = (isnothing(mode) || mode == :auto) ? nothing : !mode; m) function Base.show(io::IO, d::Dropout) print(io, "Dropout(", d.p) @@ -83,7 +83,7 @@ function (a::AlphaDropout)(x) end testmode!(m::AlphaDropout, mode = true) = - (m.active = (isnothing(mode) || mode == :auto) ? nothing : !mode) + (m.active = (isnothing(mode) || mode == :auto) ? nothing : !mode; m) """ LayerNorm(h::Integer) @@ -191,7 +191,7 @@ end @functor BatchNorm testmode!(m::BatchNorm, mode = true) = - (m.active = (isnothing(mode) || mode == :auto) ? nothing : !mode) + (m.active = (isnothing(mode) || mode == :auto) ? nothing : !mode; m) function Base.show(io::IO, l::BatchNorm) print(io, "BatchNorm($(join(size(l.β), ", "))") @@ -290,7 +290,7 @@ end @functor InstanceNorm testmode!(m::InstanceNorm, mode = true) = - (m.active = (isnothing(mode) || mode == :auto) ? nothing : !mode) + (m.active = (isnothing(mode) || mode == :auto) ? nothing : !mode; m) function Base.show(io::IO, l::InstanceNorm) print(io, "InstanceNorm($(join(size(l.β), ", "))") @@ -393,7 +393,7 @@ end @functor GroupNorm testmode!(m::GroupNorm, mode = true) = - (m.active = (isnothing(mode) || mode == :auto) ? nothing : !mode) + (m.active = (isnothing(mode) || mode == :auto) ? nothing : !mode; m) function Base.show(io::IO, l::GroupNorm) print(io, "GroupNorm($(join(size(l.β), ", "))") diff --git a/test/layers/normalisation.jl b/test/layers/normalisation.jl index 594fb586..79bd9c77 100644 --- a/test/layers/normalisation.jl +++ b/test/layers/normalisation.jl @@ -85,19 +85,19 @@ end @test isapprox(y, sigmoid.((x .- m.μ) ./ sqrt.(m.σ² .+ m.ϵ)), atol = 1.0e-7) end - let m = trainmode(BatchNorm(2)), x = reshape(Float32.(1:6), 3, 2, 1) + let m = testmode!(BatchNorm(2), false), x = reshape(Float32.(1:6), 3, 2, 1) y = reshape(permutedims(x, [2, 1, 3]), 2, :) y = permutedims(reshape(m(y), 2, 3, 1), [2, 1, 3]) @test m(x) == y end - let m = trainmode(BatchNorm(2)), x = reshape(Float32.(1:12), 2, 3, 2, 1) + let m = testmode!(BatchNorm(2), false), x = reshape(Float32.(1:12), 2, 3, 2, 1) y = reshape(permutedims(x, [3, 1, 2, 4]), 2, :) y = permutedims(reshape(m(y), 2, 2, 3, 1), [2, 3, 1, 4]) @test m(x) == y end - let m = trainmode(BatchNorm(2)), x = reshape(Float32.(1:24), 2, 2, 3, 2, 1) + let m = testmode!(BatchNorm(2), false), x = reshape(Float32.(1:24), 2, 2, 3, 2, 1) y = reshape(permutedims(x, [4, 1, 2, 3, 5]), 2, :) y = permutedims(reshape(m(y), 2, 2, 2, 3, 1), [2, 3, 4, 1, 5]) @test m(x) == y @@ -165,7 +165,7 @@ end @test isapprox(y, sigmoid.((x .- expand_inst(m.μ, affine_shape)) ./ sqrt.(expand_inst(m.σ², affine_shape) .+ m.ϵ)), atol = 1.0e-7) end - let m = trainmode(InstanceNorm(2)), sizes = (2, 4, 1, 2, 3), + let m = testmode!(InstanceNorm(2), false), sizes = (2, 4, 1, 2, 3), x = Float32.(reshape(collect(1:prod(sizes)), sizes)) y = reshape(permutedims(x, [3, 1, 2, 4, 5]), :, 2, 3) y = reshape(m(y), sizes...) @@ -182,7 +182,7 @@ end end # show that instance norm is equal to batch norm when channel and batch dims are squashed - let m_inorm = trainmode(InstanceNorm(2)), m_bnorm = trainmode(BatchNorm(12)), sizes = (5, 5, 3, 4, 2, 6), + let m_inorm = testmode!(InstanceNorm(2), false), m_bnorm = testmode!(BatchNorm(12), false), sizes = (5, 5, 3, 4, 2, 6), x = reshape(Float32.(collect(1:prod(sizes))), sizes) @test m_inorm(x) == reshape(m_bnorm(reshape(x, (sizes[1:end - 2]..., :, 1))), sizes) end @@ -266,7 +266,7 @@ if VERSION >= v"1.1" @test isapprox(y, out, atol = 1.0e-7) end - let m = trainmode(GroupNorm(2,2)), sizes = (2, 4, 1, 2, 3), + let m = testmode!(GroupNorm(2,2), false), sizes = (2, 4, 1, 2, 3), x = Float32.(reshape(collect(1:prod(sizes)), sizes)) y = reshape(permutedims(x, [3, 1, 2, 4, 5]), :, 2, 3) y = reshape(m(y), sizes...) @@ -283,13 +283,13 @@ if VERSION >= v"1.1" end # show that group norm is the same as instance norm when the group size is the same as the number of channels - let IN = trainmode(InstanceNorm(4)), GN = trainmode(GroupNorm(4,4)), sizes = (2,2,3,4,5), + let IN = testmode!(InstanceNorm(4), false), GN = testmode!(GroupNorm(4,4), false), sizes = (2,2,3,4,5), x = Float32.(reshape(collect(1:prod(sizes)), sizes)) @test IN(x) ≈ GN(x) end # show that group norm is the same as batch norm for a group of size 1 and batch of size 1 - let BN = trainmode(BatchNorm(4)), GN = trainmode(GroupNorm(4,4)), sizes = (2,2,3,4,1), + let BN = testmode!(BatchNorm(4), false), GN = testmode!(GroupNorm(4,4), false), sizes = (2,2,3,4,1), x = Float32.(reshape(collect(1:prod(sizes)), sizes)) @test BN(x) ≈ GN(x) end From 568ecb1c979a6b05e379d13c2ed2d6ed45f2a71b Mon Sep 17 00:00:00 2001 From: Kyle Daruwalla Date: Sat, 29 Feb 2020 16:25:18 -0600 Subject: [PATCH 05/10] Removed trainmode from tests --- test/layers/normalisation.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/test/layers/normalisation.jl b/test/layers/normalisation.jl index 79bd9c77..f9d4849a 100644 --- a/test/layers/normalisation.jl +++ b/test/layers/normalisation.jl @@ -2,7 +2,6 @@ using Flux, Test, Statistics using Zygote: pullback evalwgrad(f, x...) = pullback(f, x...)[1] -trainmode(f) = (testmode!(f, false); f) @testset "Dropout" begin x = [1.,2.,3.] From c001d0f3c5cf8613cac2be67821cc6d0561280a4 Mon Sep 17 00:00:00 2001 From: Kyle Daruwalla Date: Sun, 1 Mar 2020 12:30:41 -0600 Subject: [PATCH 06/10] Added trainmode! and updated docs with warning --- docs/src/models/layers.md | 1 + src/Flux.jl | 2 +- src/functor.jl | 21 ++++++++++++++++++++- test/layers/normalisation.jl | 16 ++++++++-------- 4 files changed, 30 insertions(+), 10 deletions(-) diff --git a/docs/src/models/layers.md b/docs/src/models/layers.md index 763fbf8c..100cee4d 100644 --- a/docs/src/models/layers.md +++ b/docs/src/models/layers.md @@ -72,6 +72,7 @@ Many normalisation layers behave differently under training and inference (testi ```@docs testmode! +trainmode! ``` ## Cost Functions diff --git a/src/Flux.jl b/src/Flux.jl index 5f9878f3..163fcdf2 100644 --- a/src/Flux.jl +++ b/src/Flux.jl @@ -11,7 +11,7 @@ export gradient export Chain, Dense, Maxout, RNN, LSTM, GRU, Conv, CrossCor, ConvTranspose, MaxPool, MeanPool, DepthwiseConv, Dropout, AlphaDropout, LayerNorm, BatchNorm, InstanceNorm, GroupNorm, - SkipConnection, params, fmap, cpu, gpu, f32, f64, testmode! + SkipConnection, params, fmap, cpu, gpu, f32, f64, testmode!, trainmode! include("optimise/Optimise.jl") using .Optimise diff --git a/src/functor.jl b/src/functor.jl index ee384b98..fce730b1 100644 --- a/src/functor.jl +++ b/src/functor.jl @@ -40,11 +40,14 @@ end trainable(m) = functor(m)[1] """ - testmode!(m, mode = true) + testmode!(m, mode = true) Set a layer or model's test mode (see below). Using `:auto` mode will treat any gradient computation as training. +_Note_: if you manually set a model into test mode, you need to manually place +it back into train mode. + Possible values include: - `false` for training - `true` for testing @@ -52,6 +55,22 @@ Possible values include: """ testmode!(m, mode = true) = m +""" + trainmode!(m, mode = true) + +Set a layer of model's train mode (see below). +Symmetric to [`testmode`](@ref) (i.e. `trainmode!(m, mode) == testmode!(m, !mode)). + +_Note_: if you manually set a model into train mode, you need to manually place +it into test mode. + +Possible values include: +- `true` for training +- `false` for testing +- `:auto` or `nothing` for Flux to detect the mode automatically +""" +trainmode!(m, mode = true) = mode isa Bool ? testmode!(m, !mode) : testmode!(m, mode) + params!(p::Params, x::AbstractArray{<:Number}, seen = IdSet()) = push!(p, x) function params!(p::Params, x, seen = IdSet()) diff --git a/test/layers/normalisation.jl b/test/layers/normalisation.jl index f9d4849a..ed2879b0 100644 --- a/test/layers/normalisation.jl +++ b/test/layers/normalisation.jl @@ -84,19 +84,19 @@ end @test isapprox(y, sigmoid.((x .- m.μ) ./ sqrt.(m.σ² .+ m.ϵ)), atol = 1.0e-7) end - let m = testmode!(BatchNorm(2), false), x = reshape(Float32.(1:6), 3, 2, 1) + let m = trainmode!(BatchNorm(2)), x = reshape(Float32.(1:6), 3, 2, 1) y = reshape(permutedims(x, [2, 1, 3]), 2, :) y = permutedims(reshape(m(y), 2, 3, 1), [2, 1, 3]) @test m(x) == y end - let m = testmode!(BatchNorm(2), false), x = reshape(Float32.(1:12), 2, 3, 2, 1) + let m = trainmode!(BatchNorm(2)), x = reshape(Float32.(1:12), 2, 3, 2, 1) y = reshape(permutedims(x, [3, 1, 2, 4]), 2, :) y = permutedims(reshape(m(y), 2, 2, 3, 1), [2, 3, 1, 4]) @test m(x) == y end - let m = testmode!(BatchNorm(2), false), x = reshape(Float32.(1:24), 2, 2, 3, 2, 1) + let m = trainmode!(BatchNorm(2)), x = reshape(Float32.(1:24), 2, 2, 3, 2, 1) y = reshape(permutedims(x, [4, 1, 2, 3, 5]), 2, :) y = permutedims(reshape(m(y), 2, 2, 2, 3, 1), [2, 3, 4, 1, 5]) @test m(x) == y @@ -164,7 +164,7 @@ end @test isapprox(y, sigmoid.((x .- expand_inst(m.μ, affine_shape)) ./ sqrt.(expand_inst(m.σ², affine_shape) .+ m.ϵ)), atol = 1.0e-7) end - let m = testmode!(InstanceNorm(2), false), sizes = (2, 4, 1, 2, 3), + let m = trainmode!(InstanceNorm(2)), sizes = (2, 4, 1, 2, 3), x = Float32.(reshape(collect(1:prod(sizes)), sizes)) y = reshape(permutedims(x, [3, 1, 2, 4, 5]), :, 2, 3) y = reshape(m(y), sizes...) @@ -181,7 +181,7 @@ end end # show that instance norm is equal to batch norm when channel and batch dims are squashed - let m_inorm = testmode!(InstanceNorm(2), false), m_bnorm = testmode!(BatchNorm(12), false), sizes = (5, 5, 3, 4, 2, 6), + let m_inorm = trainmode!(InstanceNorm(2)), m_bnorm = trainmode!(BatchNorm(12)), sizes = (5, 5, 3, 4, 2, 6), x = reshape(Float32.(collect(1:prod(sizes))), sizes) @test m_inorm(x) == reshape(m_bnorm(reshape(x, (sizes[1:end - 2]..., :, 1))), sizes) end @@ -265,7 +265,7 @@ if VERSION >= v"1.1" @test isapprox(y, out, atol = 1.0e-7) end - let m = testmode!(GroupNorm(2,2), false), sizes = (2, 4, 1, 2, 3), + let m = trainmode!(GroupNorm(2,2)), sizes = (2, 4, 1, 2, 3), x = Float32.(reshape(collect(1:prod(sizes)), sizes)) y = reshape(permutedims(x, [3, 1, 2, 4, 5]), :, 2, 3) y = reshape(m(y), sizes...) @@ -282,13 +282,13 @@ if VERSION >= v"1.1" end # show that group norm is the same as instance norm when the group size is the same as the number of channels - let IN = testmode!(InstanceNorm(4), false), GN = testmode!(GroupNorm(4,4), false), sizes = (2,2,3,4,5), + let IN = trainmode!(InstanceNorm(4)), GN = trainmode!(GroupNorm(4,4)), sizes = (2,2,3,4,5), x = Float32.(reshape(collect(1:prod(sizes)), sizes)) @test IN(x) ≈ GN(x) end # show that group norm is the same as batch norm for a group of size 1 and batch of size 1 - let BN = testmode!(BatchNorm(4), false), GN = testmode!(GroupNorm(4,4), false), sizes = (2,2,3,4,1), + let BN = trainmode!(BatchNorm(4)), GN = trainmode!(GroupNorm(4,4)), sizes = (2,2,3,4,1), x = Float32.(reshape(collect(1:prod(sizes)), sizes)) @test BN(x) ≈ GN(x) end From 35e460b044d47433999c5719111ff1b14138fef2 Mon Sep 17 00:00:00 2001 From: Kyle Daruwalla Date: Sun, 1 Mar 2020 12:44:36 -0600 Subject: [PATCH 07/10] Fixed broken @ref in docstring --- src/functor.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/functor.jl b/src/functor.jl index fce730b1..ba8c9212 100644 --- a/src/functor.jl +++ b/src/functor.jl @@ -59,7 +59,7 @@ testmode!(m, mode = true) = m trainmode!(m, mode = true) Set a layer of model's train mode (see below). -Symmetric to [`testmode`](@ref) (i.e. `trainmode!(m, mode) == testmode!(m, !mode)). +Symmetric to [`testmode!`](@ref) (i.e. `trainmode!(m, mode) == testmode!(m, !mode)). _Note_: if you manually set a model into train mode, you need to manually place it into test mode. From 23f791e32b6176500d0a48af1afe90b4f8a7958c Mon Sep 17 00:00:00 2001 From: Kyle Daruwalla Date: Sun, 1 Mar 2020 12:49:30 -0600 Subject: [PATCH 08/10] Add "during X phase" phrasing to testmode!/trainmode! docstring. --- src/functor.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/functor.jl b/src/functor.jl index ba8c9212..0d7c55f1 100644 --- a/src/functor.jl +++ b/src/functor.jl @@ -46,7 +46,7 @@ Set a layer or model's test mode (see below). Using `:auto` mode will treat any gradient computation as training. _Note_: if you manually set a model into test mode, you need to manually place -it back into train mode. +it back into train mode during training phase. Possible values include: - `false` for training @@ -62,7 +62,7 @@ Set a layer of model's train mode (see below). Symmetric to [`testmode!`](@ref) (i.e. `trainmode!(m, mode) == testmode!(m, !mode)). _Note_: if you manually set a model into train mode, you need to manually place -it into test mode. +it into test mode during testing phase. Possible values include: - `true` for training From 88cad1c5e7fb1d16702bff72444a3b91c7bb9469 Mon Sep 17 00:00:00 2001 From: Kyle Daruwalla Date: Sun, 1 Mar 2020 12:50:49 -0600 Subject: [PATCH 09/10] Bump minor version to v0.10.3 --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index bd105730..f88d2451 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "Flux" uuid = "587475ba-b771-5e3f-ad9e-33799f191a9c" -version = "0.10.2" +version = "0.10.3" [deps] AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c" From e49d9c4537714441730f4023b12b168916246137 Mon Sep 17 00:00:00 2001 From: Kyle Daruwalla Date: Sun, 1 Mar 2020 13:11:07 -0600 Subject: [PATCH 10/10] Debump version --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index f88d2451..bd105730 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "Flux" uuid = "587475ba-b771-5e3f-ad9e-33799f191a9c" -version = "0.10.3" +version = "0.10.2" [deps] AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"