Added trainmode! and updated docs with warning

Kyle Daruwalla 2020-03-01 12:30:41 -06:00
parent 568ecb1c97
commit c001d0f3c5
4 changed files with 30 additions and 10 deletions


@@ -72,6 +72,7 @@ Many normalisation layers behave differently under training and inference (testing)
 ```@docs
 testmode!
+trainmode!
 ```

 ## Cost Functions
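For context on what the new `trainmode!` docs entry covers, a minimal usage sketch, assuming the Flux API exported in this commit (the `BatchNorm` layer and the input shape below are illustrative only), showing how a normalisation layer's behaviour depends on the mode:

```julia
using Flux

m = BatchNorm(2)             # 2 channels; output depends on the current mode
x = rand(Float32, 2, 4)      # features × batch

trainmode!(m)                # normalise with batch statistics and update running stats
y_train = m(x)

testmode!(m)                 # normalise with the accumulated running statistics
y_test = m(x)

testmode!(m, :auto)          # default: let Flux detect the mode from gradient context
```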


@@ -11,7 +11,7 @@ export gradient
 export Chain, Dense, Maxout, RNN, LSTM, GRU, Conv, CrossCor, ConvTranspose, MaxPool, MeanPool,
        DepthwiseConv, Dropout, AlphaDropout, LayerNorm, BatchNorm, InstanceNorm, GroupNorm,
-       SkipConnection, params, fmap, cpu, gpu, f32, f64, testmode!
+       SkipConnection, params, fmap, cpu, gpu, f32, f64, testmode!, trainmode!

 include("optimise/Optimise.jl")
 using .Optimise


@@ -40,11 +40,14 @@ end
 trainable(m) = functor(m)[1]

 """
     testmode!(m, mode = true)

 Set a layer or model's test mode (see below).
 Using `:auto` mode will treat any gradient computation as training.

+_Note_: if you manually set a model into test mode, you need to manually place
+it back into train mode.
+
 Possible values include:
 - `false` for training
 - `true` for testing
@@ -52,6 +55,22 @@ Possible values include:
 """
 testmode!(m, mode = true) = m

+"""
+    trainmode!(m, mode = true)
+
+Set a layer or model's train mode (see below).
+Symmetric to [`testmode!`](@ref) (i.e. `trainmode!(m, mode) == testmode!(m, !mode)`).
+
+_Note_: if you manually set a model into train mode, you need to manually place
+it back into test mode.
+
+Possible values include:
+- `true` for training
+- `false` for testing
+- `:auto` or `nothing` for Flux to detect the mode automatically
+"""
+trainmode!(m, mode = true) = mode isa Bool ? testmode!(m, !mode) : testmode!(m, mode)
+
 params!(p::Params, x::AbstractArray{<:Number}, seen = IdSet()) = push!(p, x)

 function params!(p::Params, x, seen = IdSet())
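To make the symmetry stated in the docstring concrete, a small sketch, assuming `Dropout` as an example of a mode-aware layer (any layer that implements `testmode!` behaves the same way):

```julia
using Flux

d = Dropout(0.5)

trainmode!(d)          # same as testmode!(d, false): dropout is active
trainmode!(d, false)   # same as testmode!(d, true): dropout is disabled
trainmode!(d, :auto)   # non-Bool modes are forwarded to testmode! unchanged
```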


@@ -84,19 +84,19 @@ end
     @test isapprox(y, sigmoid.((x .- m.μ) ./ sqrt.(m.σ² .+ m.ϵ)), atol = 1.0e-7)
   end

-  let m = testmode!(BatchNorm(2), false), x = reshape(Float32.(1:6), 3, 2, 1)
+  let m = trainmode!(BatchNorm(2)), x = reshape(Float32.(1:6), 3, 2, 1)
     y = reshape(permutedims(x, [2, 1, 3]), 2, :)
     y = permutedims(reshape(m(y), 2, 3, 1), [2, 1, 3])
     @test m(x) == y
   end

-  let m = testmode!(BatchNorm(2), false), x = reshape(Float32.(1:12), 2, 3, 2, 1)
+  let m = trainmode!(BatchNorm(2)), x = reshape(Float32.(1:12), 2, 3, 2, 1)
     y = reshape(permutedims(x, [3, 1, 2, 4]), 2, :)
     y = permutedims(reshape(m(y), 2, 2, 3, 1), [2, 3, 1, 4])
     @test m(x) == y
   end

-  let m = testmode!(BatchNorm(2), false), x = reshape(Float32.(1:24), 2, 2, 3, 2, 1)
+  let m = trainmode!(BatchNorm(2)), x = reshape(Float32.(1:24), 2, 2, 3, 2, 1)
     y = reshape(permutedims(x, [4, 1, 2, 3, 5]), 2, :)
     y = permutedims(reshape(m(y), 2, 2, 2, 3, 1), [2, 3, 4, 1, 5])
     @test m(x) == y
@@ -164,7 +164,7 @@ end
     @test isapprox(y, sigmoid.((x .- expand_inst(m.μ, affine_shape)) ./ sqrt.(expand_inst(m.σ², affine_shape) .+ m.ϵ)), atol = 1.0e-7)
   end

-  let m = testmode!(InstanceNorm(2), false), sizes = (2, 4, 1, 2, 3),
+  let m = trainmode!(InstanceNorm(2)), sizes = (2, 4, 1, 2, 3),
       x = Float32.(reshape(collect(1:prod(sizes)), sizes))
     y = reshape(permutedims(x, [3, 1, 2, 4, 5]), :, 2, 3)
     y = reshape(m(y), sizes...)
@@ -181,7 +181,7 @@ end
   end

   # show that instance norm is equal to batch norm when channel and batch dims are squashed
-  let m_inorm = testmode!(InstanceNorm(2), false), m_bnorm = testmode!(BatchNorm(12), false), sizes = (5, 5, 3, 4, 2, 6),
+  let m_inorm = trainmode!(InstanceNorm(2)), m_bnorm = trainmode!(BatchNorm(12)), sizes = (5, 5, 3, 4, 2, 6),
       x = reshape(Float32.(collect(1:prod(sizes))), sizes)
     @test m_inorm(x) == reshape(m_bnorm(reshape(x, (sizes[1:end - 2]..., :, 1))), sizes)
   end
@@ -265,7 +265,7 @@ if VERSION >= v"1.1"
     @test isapprox(y, out, atol = 1.0e-7)
   end

-  let m = testmode!(GroupNorm(2,2), false), sizes = (2, 4, 1, 2, 3),
+  let m = trainmode!(GroupNorm(2,2)), sizes = (2, 4, 1, 2, 3),
       x = Float32.(reshape(collect(1:prod(sizes)), sizes))
     y = reshape(permutedims(x, [3, 1, 2, 4, 5]), :, 2, 3)
     y = reshape(m(y), sizes...)
@@ -282,13 +282,13 @@ if VERSION >= v"1.1"
   end

   # show that group norm is the same as instance norm when the group size is the same as the number of channels
-  let IN = testmode!(InstanceNorm(4), false), GN = testmode!(GroupNorm(4,4), false), sizes = (2,2,3,4,5),
+  let IN = trainmode!(InstanceNorm(4)), GN = trainmode!(GroupNorm(4,4)), sizes = (2,2,3,4,5),
       x = Float32.(reshape(collect(1:prod(sizes)), sizes))
     @test IN(x) ≈ GN(x)
   end

   # show that group norm is the same as batch norm for a group of size 1 and batch of size 1
-  let BN = testmode!(BatchNorm(4), false), GN = testmode!(GroupNorm(4,4), false), sizes = (2,2,3,4,1),
+  let BN = trainmode!(BatchNorm(4)), GN = trainmode!(GroupNorm(4,4)), sizes = (2,2,3,4,1),
       x = Float32.(reshape(collect(1:prod(sizes)), sizes))
     @test BN(x) ≈ GN(x)
   end