Added trainmode! and updated docs with warning
This commit is contained in:
parent
568ecb1c97
commit
c001d0f3c5
|
@ -72,6 +72,7 @@ Many normalisation layers behave differently under training and inference (testi
|
|||
|
||||
```@docs
|
||||
testmode!
|
||||
trainmode!
|
||||
```
|
||||
|
||||
## Cost Functions
|
||||
|
|
|
@ -11,7 +11,7 @@ export gradient
|
|||
|
||||
export Chain, Dense, Maxout, RNN, LSTM, GRU, Conv, CrossCor, ConvTranspose, MaxPool, MeanPool,
|
||||
DepthwiseConv, Dropout, AlphaDropout, LayerNorm, BatchNorm, InstanceNorm, GroupNorm,
|
||||
SkipConnection, params, fmap, cpu, gpu, f32, f64, testmode!
|
||||
SkipConnection, params, fmap, cpu, gpu, f32, f64, testmode!, trainmode!
|
||||
|
||||
include("optimise/Optimise.jl")
|
||||
using .Optimise
|
||||
|
|
|
@ -40,11 +40,14 @@ end
|
|||
trainable(m) = functor(m)[1]
|
||||
|
||||
"""
|
||||
testmode!(m, mode = true)
|
||||
testmode!(m, mode = true)
|
||||
|
||||
Set a layer or model's test mode (see below).
|
||||
Using `:auto` mode will treat any gradient computation as training.
|
||||
|
||||
_Note_: if you manually set a model into test mode, you need to manually place
|
||||
it back into train mode.
|
||||
|
||||
Possible values include:
|
||||
- `false` for training
|
||||
- `true` for testing
|
||||
|
@ -52,6 +55,22 @@ Possible values include:
|
|||
"""
|
||||
testmode!(m, mode = true) = m
|
||||
|
||||
"""
|
||||
trainmode!(m, mode = true)
|
||||
|
||||
Set a layer of model's train mode (see below).
|
||||
Symmetric to [`testmode`](@ref) (i.e. `trainmode!(m, mode) == testmode!(m, !mode)).
|
||||
|
||||
_Note_: if you manually set a model into train mode, you need to manually place
|
||||
it into test mode.
|
||||
|
||||
Possible values include:
|
||||
- `true` for training
|
||||
- `false` for testing
|
||||
- `:auto` or `nothing` for Flux to detect the mode automatically
|
||||
"""
|
||||
trainmode!(m, mode = true) = mode isa Bool ? testmode!(m, !mode) : testmode!(m, mode)
|
||||
|
||||
params!(p::Params, x::AbstractArray{<:Number}, seen = IdSet()) = push!(p, x)
|
||||
|
||||
function params!(p::Params, x, seen = IdSet())
|
||||
|
|
|
@ -84,19 +84,19 @@ end
|
|||
@test isapprox(y, sigmoid.((x .- m.μ) ./ sqrt.(m.σ² .+ m.ϵ)), atol = 1.0e-7)
|
||||
end
|
||||
|
||||
let m = testmode!(BatchNorm(2), false), x = reshape(Float32.(1:6), 3, 2, 1)
|
||||
let m = trainmode!(BatchNorm(2)), x = reshape(Float32.(1:6), 3, 2, 1)
|
||||
y = reshape(permutedims(x, [2, 1, 3]), 2, :)
|
||||
y = permutedims(reshape(m(y), 2, 3, 1), [2, 1, 3])
|
||||
@test m(x) == y
|
||||
end
|
||||
|
||||
let m = testmode!(BatchNorm(2), false), x = reshape(Float32.(1:12), 2, 3, 2, 1)
|
||||
let m = trainmode!(BatchNorm(2)), x = reshape(Float32.(1:12), 2, 3, 2, 1)
|
||||
y = reshape(permutedims(x, [3, 1, 2, 4]), 2, :)
|
||||
y = permutedims(reshape(m(y), 2, 2, 3, 1), [2, 3, 1, 4])
|
||||
@test m(x) == y
|
||||
end
|
||||
|
||||
let m = testmode!(BatchNorm(2), false), x = reshape(Float32.(1:24), 2, 2, 3, 2, 1)
|
||||
let m = trainmode!(BatchNorm(2)), x = reshape(Float32.(1:24), 2, 2, 3, 2, 1)
|
||||
y = reshape(permutedims(x, [4, 1, 2, 3, 5]), 2, :)
|
||||
y = permutedims(reshape(m(y), 2, 2, 2, 3, 1), [2, 3, 4, 1, 5])
|
||||
@test m(x) == y
|
||||
|
@ -164,7 +164,7 @@ end
|
|||
@test isapprox(y, sigmoid.((x .- expand_inst(m.μ, affine_shape)) ./ sqrt.(expand_inst(m.σ², affine_shape) .+ m.ϵ)), atol = 1.0e-7)
|
||||
end
|
||||
|
||||
let m = testmode!(InstanceNorm(2), false), sizes = (2, 4, 1, 2, 3),
|
||||
let m = trainmode!(InstanceNorm(2)), sizes = (2, 4, 1, 2, 3),
|
||||
x = Float32.(reshape(collect(1:prod(sizes)), sizes))
|
||||
y = reshape(permutedims(x, [3, 1, 2, 4, 5]), :, 2, 3)
|
||||
y = reshape(m(y), sizes...)
|
||||
|
@ -181,7 +181,7 @@ end
|
|||
end
|
||||
|
||||
# show that instance norm is equal to batch norm when channel and batch dims are squashed
|
||||
let m_inorm = testmode!(InstanceNorm(2), false), m_bnorm = testmode!(BatchNorm(12), false), sizes = (5, 5, 3, 4, 2, 6),
|
||||
let m_inorm = trainmode!(InstanceNorm(2)), m_bnorm = trainmode!(BatchNorm(12)), sizes = (5, 5, 3, 4, 2, 6),
|
||||
x = reshape(Float32.(collect(1:prod(sizes))), sizes)
|
||||
@test m_inorm(x) == reshape(m_bnorm(reshape(x, (sizes[1:end - 2]..., :, 1))), sizes)
|
||||
end
|
||||
|
@ -265,7 +265,7 @@ if VERSION >= v"1.1"
|
|||
@test isapprox(y, out, atol = 1.0e-7)
|
||||
end
|
||||
|
||||
let m = testmode!(GroupNorm(2,2), false), sizes = (2, 4, 1, 2, 3),
|
||||
let m = trainmode!(GroupNorm(2,2)), sizes = (2, 4, 1, 2, 3),
|
||||
x = Float32.(reshape(collect(1:prod(sizes)), sizes))
|
||||
y = reshape(permutedims(x, [3, 1, 2, 4, 5]), :, 2, 3)
|
||||
y = reshape(m(y), sizes...)
|
||||
|
@ -282,13 +282,13 @@ if VERSION >= v"1.1"
|
|||
end
|
||||
|
||||
# show that group norm is the same as instance norm when the group size is the same as the number of channels
|
||||
let IN = testmode!(InstanceNorm(4), false), GN = testmode!(GroupNorm(4,4), false), sizes = (2,2,3,4,5),
|
||||
let IN = trainmode!(InstanceNorm(4)), GN = trainmode!(GroupNorm(4,4)), sizes = (2,2,3,4,5),
|
||||
x = Float32.(reshape(collect(1:prod(sizes)), sizes))
|
||||
@test IN(x) ≈ GN(x)
|
||||
end
|
||||
|
||||
# show that group norm is the same as batch norm for a group of size 1 and batch of size 1
|
||||
let BN = testmode!(BatchNorm(4), false), GN = testmode!(GroupNorm(4,4), false), sizes = (2,2,3,4,1),
|
||||
let BN = trainmode!(BatchNorm(4)), GN = trainmode!(GroupNorm(4,4)), sizes = (2,2,3,4,1),
|
||||
x = Float32.(reshape(collect(1:prod(sizes)), sizes))
|
||||
@test BN(x) ≈ GN(x)
|
||||
end
|
||||
|
|
Loading…
Reference in New Issue