Added trainmode! and updated docs with warning

Kyle Daruwalla 2020-03-01 12:30:41 -06:00
parent 568ecb1c97
commit c001d0f3c5
4 changed files with 30 additions and 10 deletions


@@ -72,6 +72,7 @@ Many normalisation layers behave differently under training and inference (testing).
```@docs
testmode!
+trainmode!
```
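For context, a minimal sketch of how the two documented functions are used together; the `Chain`/`Dropout` model here is illustrative, not part of this commit:

```julia
using Flux

# Hypothetical model: Dropout behaves differently in train mode
# (random masking) and test mode (identity).
m = Chain(Dense(10, 5), Dropout(0.5))

trainmode!(m)        # force train mode, even outside a gradient call
testmode!(m)         # force test mode for evaluation
testmode!(m, :auto)  # return to automatic train/test detection
```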
## Cost Functions


@@ -11,7 +11,7 @@ export gradient
export Chain, Dense, Maxout, RNN, LSTM, GRU, Conv, CrossCor, ConvTranspose, MaxPool, MeanPool,
DepthwiseConv, Dropout, AlphaDropout, LayerNorm, BatchNorm, InstanceNorm, GroupNorm,
-        SkipConnection, params, fmap, cpu, gpu, f32, f64, testmode!
+        SkipConnection, params, fmap, cpu, gpu, f32, f64, testmode!, trainmode!
include("optimise/Optimise.jl")
using .Optimise


@@ -40,11 +40,14 @@ end
trainable(m) = functor(m)[1]
"""
    testmode!(m, mode = true)
Set a layer or model's test mode (see below).
Using `:auto` mode will treat any gradient computation as training.
+_Note_: if you manually set a model into test mode, you need to manually place
+it back into train mode.
Possible values include:
- `false` for training
- `true` for testing
@@ -52,6 +55,22 @@ Possible values include:
"""
testmode!(m, mode = true) = m
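A small sketch of the caveat the note above describes; the layer and values are illustrative:

```julia
using Flux

d = Dropout(0.3)
testmode!(d)   # manually forced into test mode...

# ...so dropout stays disabled even inside a gradient computation,
# which would otherwise count as training under `:auto`.
gradient(x -> sum(d(x)), rand(Float32, 4))

testmode!(d, :auto)  # manually restore automatic detection
```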
"""
trainmode!(m, mode = true)
Set a layer of model's train mode (see below).
Symmetric to [`testmode`](@ref) (i.e. `trainmode!(m, mode) == testmode!(m, !mode)).
_Note_: if you manually set a model into train mode, you need to manually place
it into test mode.
Possible values include:
- `true` for training
- `false` for testing
- `:auto` or `nothing` for Flux to detect the mode automatically
"""
trainmode!(m, mode = true) = mode isa Bool ? testmode!(m, !mode) : testmode!(m, mode)
params!(p::Params, x::AbstractArray{<:Number}, seen = IdSet()) = push!(p, x)
function params!(p::Params, x, seen = IdSet())
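Given the definition above, a quick sketch of the symmetry between the two functions; the `BatchNorm` layer is just an example:

```julia
using Flux

m = BatchNorm(2)

trainmode!(m)         # same as testmode!(m, false)
trainmode!(m, false)  # same as testmode!(m, true)
trainmode!(m, :auto)  # non-Bool modes are passed to testmode! unchanged
```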


@@ -84,19 +84,19 @@ end
@test isapprox(y, sigmoid.((x .- m.μ) ./ sqrt.(m.σ² .+ m.ϵ)), atol = 1.0e-7)
end
-    let m = testmode!(BatchNorm(2), false), x = reshape(Float32.(1:6), 3, 2, 1)
+    let m = trainmode!(BatchNorm(2)), x = reshape(Float32.(1:6), 3, 2, 1)
y = reshape(permutedims(x, [2, 1, 3]), 2, :)
y = permutedims(reshape(m(y), 2, 3, 1), [2, 1, 3])
@test m(x) == y
end
-    let m = testmode!(BatchNorm(2), false), x = reshape(Float32.(1:12), 2, 3, 2, 1)
+    let m = trainmode!(BatchNorm(2)), x = reshape(Float32.(1:12), 2, 3, 2, 1)
y = reshape(permutedims(x, [3, 1, 2, 4]), 2, :)
y = permutedims(reshape(m(y), 2, 2, 3, 1), [2, 3, 1, 4])
@test m(x) == y
end
-    let m = testmode!(BatchNorm(2), false), x = reshape(Float32.(1:24), 2, 2, 3, 2, 1)
+    let m = trainmode!(BatchNorm(2)), x = reshape(Float32.(1:24), 2, 2, 3, 2, 1)
y = reshape(permutedims(x, [4, 1, 2, 3, 5]), 2, :)
y = permutedims(reshape(m(y), 2, 2, 2, 3, 1), [2, 3, 4, 1, 5])
@test m(x) == y
@@ -164,7 +164,7 @@ end
@test isapprox(y, sigmoid.((x .- expand_inst(m.μ, affine_shape)) ./ sqrt.(expand_inst(m.σ², affine_shape) .+ m.ϵ)), atol = 1.0e-7)
end
-    let m = testmode!(InstanceNorm(2), false), sizes = (2, 4, 1, 2, 3),
+    let m = trainmode!(InstanceNorm(2)), sizes = (2, 4, 1, 2, 3),
x = Float32.(reshape(collect(1:prod(sizes)), sizes))
y = reshape(permutedims(x, [3, 1, 2, 4, 5]), :, 2, 3)
y = reshape(m(y), sizes...)
@@ -181,7 +181,7 @@ end
end
# show that instance norm is equal to batch norm when channel and batch dims are squashed
-    let m_inorm = testmode!(InstanceNorm(2), false), m_bnorm = testmode!(BatchNorm(12), false), sizes = (5, 5, 3, 4, 2, 6),
+    let m_inorm = trainmode!(InstanceNorm(2)), m_bnorm = trainmode!(BatchNorm(12)), sizes = (5, 5, 3, 4, 2, 6),
x = reshape(Float32.(collect(1:prod(sizes))), sizes)
@test m_inorm(x) == reshape(m_bnorm(reshape(x, (sizes[1:end - 2]..., :, 1))), sizes)
end
@@ -265,7 +265,7 @@ if VERSION >= v"1.1"
@test isapprox(y, out, atol = 1.0e-7)
end
-    let m = testmode!(GroupNorm(2,2), false), sizes = (2, 4, 1, 2, 3),
+    let m = trainmode!(GroupNorm(2,2)), sizes = (2, 4, 1, 2, 3),
x = Float32.(reshape(collect(1:prod(sizes)), sizes))
y = reshape(permutedims(x, [3, 1, 2, 4, 5]), :, 2, 3)
y = reshape(m(y), sizes...)
@@ -282,13 +282,13 @@ end
end
# show that group norm is the same as instance norm when the group size is the same as the number of channels
-    let IN = testmode!(InstanceNorm(4), false), GN = testmode!(GroupNorm(4,4), false), sizes = (2,2,3,4,5),
+    let IN = trainmode!(InstanceNorm(4)), GN = trainmode!(GroupNorm(4,4)), sizes = (2,2,3,4,5),
x = Float32.(reshape(collect(1:prod(sizes)), sizes))
@test IN(x) ≈ GN(x)
end
# show that group norm is the same as batch norm for a group of size 1 and batch of size 1
-    let BN = testmode!(BatchNorm(4), false), GN = testmode!(GroupNorm(4,4), false), sizes = (2,2,3,4,1),
+    let BN = trainmode!(BatchNorm(4)), GN = trainmode!(GroupNorm(4,4)), sizes = (2,2,3,4,1),
x = Float32.(reshape(collect(1:prod(sizes)), sizes))
@test BN(x) ≈ GN(x)
end
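For reference, a minimal sketch of the equivalence these test updates rely on: freshly constructed layers initialise identically, so forcing train mode via the new `trainmode!` helper matches the old `testmode!(m, false)` spelling.

```julia
using Flux, Test

m_old = testmode!(BatchNorm(2), false)  # old spelling: "not test mode"
m_new = trainmode!(BatchNorm(2))        # new helper added in this commit

x = reshape(Float32.(1:6), 3, 2, 1)
@test m_old(x) == m_new(x)  # identical parameters, identical output
```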