From c001d0f3c5cf8613cac2be67821cc6d0561280a4 Mon Sep 17 00:00:00 2001
From: Kyle Daruwalla
Date: Sun, 1 Mar 2020 12:30:41 -0600
Subject: [PATCH] Added trainmode! and updated docs with warning

---
 docs/src/models/layers.md    |  1 +
 src/Flux.jl                  |  2 +-
 src/functor.jl               | 21 ++++++++++++++++++++-
 test/layers/normalisation.jl | 16 ++++++++--------
 4 files changed, 30 insertions(+), 10 deletions(-)

diff --git a/docs/src/models/layers.md b/docs/src/models/layers.md
index 763fbf8c..100cee4d 100644
--- a/docs/src/models/layers.md
+++ b/docs/src/models/layers.md
@@ -72,6 +72,7 @@ Many normalisation layers behave differently under training and inference (testi
 
 ```@docs
 testmode!
+trainmode!
 ```
 
 ## Cost Functions
diff --git a/src/Flux.jl b/src/Flux.jl
index 5f9878f3..163fcdf2 100644
--- a/src/Flux.jl
+++ b/src/Flux.jl
@@ -11,7 +11,7 @@ export gradient
 
 export Chain, Dense, Maxout, RNN, LSTM, GRU, Conv, CrossCor, ConvTranspose, MaxPool, MeanPool,
        DepthwiseConv, Dropout, AlphaDropout, LayerNorm, BatchNorm, InstanceNorm, GroupNorm,
-       SkipConnection, params, fmap, cpu, gpu, f32, f64, testmode!
+       SkipConnection, params, fmap, cpu, gpu, f32, f64, testmode!, trainmode!
 
 include("optimise/Optimise.jl")
 using .Optimise
diff --git a/src/functor.jl b/src/functor.jl
index ee384b98..fce730b1 100644
--- a/src/functor.jl
+++ b/src/functor.jl
@@ -40,11 +40,14 @@ end
 trainable(m) = functor(m)[1]
 
 """
-  testmode!(m, mode = true)
+    testmode!(m, mode = true)
 
 Set a layer or model's test mode (see below).
 Using `:auto` mode will treat any gradient computation as training.
 
+_Note_: if you manually set a model into test mode, you need to manually place
+it back into train mode.
+
 Possible values include:
 - `false` for training
 - `true` for testing
@@ -52,6 +55,22 @@ Possible values include:
 """
 testmode!(m, mode = true) = m
 
+"""
+    trainmode!(m, mode = true)
+
+Set a layer or model's train mode (see below).
+Symmetric to [`testmode!`](@ref) (i.e. `trainmode!(m, mode) == testmode!(m, !mode)`).
+
+_Note_: if you manually set a model into train mode, you need to manually place
+it back into test mode.
+
+Possible values include:
+- `true` for training
+- `false` for testing
+- `:auto` or `nothing` for Flux to detect the mode automatically
+"""
+trainmode!(m, mode = true) = mode isa Bool ? testmode!(m, !mode) : testmode!(m, mode)
+
 params!(p::Params, x::AbstractArray{<:Number}, seen = IdSet()) = push!(p, x)
 
 function params!(p::Params, x, seen = IdSet())
diff --git a/test/layers/normalisation.jl b/test/layers/normalisation.jl
index f9d4849a..ed2879b0 100644
--- a/test/layers/normalisation.jl
+++ b/test/layers/normalisation.jl
@@ -84,19 +84,19 @@ end
     @test isapprox(y, sigmoid.((x .- m.μ) ./ sqrt.(m.σ² .+ m.ϵ)), atol = 1.0e-7)
   end
 
-  let m = testmode!(BatchNorm(2), false), x = reshape(Float32.(1:6), 3, 2, 1)
+  let m = trainmode!(BatchNorm(2)), x = reshape(Float32.(1:6), 3, 2, 1)
     y = reshape(permutedims(x, [2, 1, 3]), 2, :)
     y = permutedims(reshape(m(y), 2, 3, 1), [2, 1, 3])
     @test m(x) == y
   end
 
-  let m = testmode!(BatchNorm(2), false), x = reshape(Float32.(1:12), 2, 3, 2, 1)
+  let m = trainmode!(BatchNorm(2)), x = reshape(Float32.(1:12), 2, 3, 2, 1)
     y = reshape(permutedims(x, [3, 1, 2, 4]), 2, :)
     y = permutedims(reshape(m(y), 2, 2, 3, 1), [2, 3, 1, 4])
     @test m(x) == y
   end
 
-  let m = testmode!(BatchNorm(2), false), x = reshape(Float32.(1:24), 2, 2, 3, 2, 1)
+  let m = trainmode!(BatchNorm(2)), x = reshape(Float32.(1:24), 2, 2, 3, 2, 1)
     y = reshape(permutedims(x, [4, 1, 2, 3, 5]), 2, :)
     y = permutedims(reshape(m(y), 2, 2, 2, 3, 1), [2, 3, 4, 1, 5])
     @test m(x) == y
@@ -164,7 +164,7 @@ end
     @test isapprox(y, sigmoid.((x .- expand_inst(m.μ, affine_shape)) ./ sqrt.(expand_inst(m.σ², affine_shape) .+ m.ϵ)), atol = 1.0e-7)
   end
 
-  let m = testmode!(InstanceNorm(2), false), sizes = (2, 4, 1, 2, 3),
+  let m = trainmode!(InstanceNorm(2)), sizes = (2, 4, 1, 2, 3),
     x = Float32.(reshape(collect(1:prod(sizes)), sizes))
     y = reshape(permutedims(x, [3, 1, 2, 4, 5]), :, 2, 3)
     y = reshape(m(y), sizes...)
@@ -181,7 +181,7 @@ end
   end
 
   # show that instance norm is equal to batch norm when channel and batch dims are squashed
-  let m_inorm = testmode!(InstanceNorm(2), false), m_bnorm = testmode!(BatchNorm(12), false), sizes = (5, 5, 3, 4, 2, 6),
+  let m_inorm = trainmode!(InstanceNorm(2)), m_bnorm = trainmode!(BatchNorm(12)), sizes = (5, 5, 3, 4, 2, 6),
     x = reshape(Float32.(collect(1:prod(sizes))), sizes)
     @test m_inorm(x) == reshape(m_bnorm(reshape(x, (sizes[1:end - 2]..., :, 1))), sizes)
   end
@@ -265,7 +265,7 @@ if VERSION >= v"1.1"
     @test isapprox(y, out, atol = 1.0e-7)
   end
 
-  let m = testmode!(GroupNorm(2,2), false), sizes = (2, 4, 1, 2, 3),
+  let m = trainmode!(GroupNorm(2,2)), sizes = (2, 4, 1, 2, 3),
     x = Float32.(reshape(collect(1:prod(sizes)), sizes))
     y = reshape(permutedims(x, [3, 1, 2, 4, 5]), :, 2, 3)
     y = reshape(m(y), sizes...)
@@ -282,13 +282,13 @@
   end
 
   # show that group norm is the same as instance norm when the group size is the same as the number of channels
-  let IN = testmode!(InstanceNorm(4), false), GN = testmode!(GroupNorm(4,4), false), sizes = (2,2,3,4,5),
+  let IN = trainmode!(InstanceNorm(4)), GN = trainmode!(GroupNorm(4,4)), sizes = (2,2,3,4,5),
     x = Float32.(reshape(collect(1:prod(sizes)), sizes))
     @test IN(x) ≈ GN(x)
  end
 
   # show that group norm is the same as batch norm for a group of size 1 and batch of size 1
-  let BN = testmode!(BatchNorm(4), false), GN = testmode!(GroupNorm(4,4), false), sizes = (2,2,3,4,1),
+  let BN = trainmode!(BatchNorm(4)), GN = trainmode!(GroupNorm(4,4)), sizes = (2,2,3,4,1),
     x = Float32.(reshape(collect(1:prod(sizes)), sizes))
     @test BN(x) ≈ GN(x)
   end
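
Not part of the patch itself: a minimal usage sketch of the `trainmode!`/`testmode!` API this diff adds, assuming a small hypothetical model containing a `Dropout` layer (the layer sizes and input data below are illustrative only, not taken from the PR):

    using Flux

    # Hypothetical model; any layer that behaves differently at train/test time applies.
    m = Chain(Dense(10, 5, relu), Dropout(0.5), Dense(5, 2))
    x = rand(Float32, 10)

    testmode!(m)          # force inference behaviour: Dropout becomes a no-op
    @assert m(x) == m(x)  # forward passes are now deterministic

    trainmode!(m)         # same as testmode!(m, false); Dropout is stochastic again

    testmode!(m, :auto)   # hand mode detection back to Flux (training inside gradient calls)

The `mode isa Bool` branch in the new `trainmode!` is what keeps the two functions symmetric: a Boolean is flipped before being forwarded to `testmode!`, while `:auto`/`nothing` is passed through unchanged.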