Added trainmode! and updated docs with warning

parent 568ecb1c97
commit c001d0f3c5
@@ -72,6 +72,7 @@ Many normalisation layers behave differently under training and inference (testing).
 
 ```@docs
 testmode!
+trainmode!
 ```
 
 ## Cost Functions
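For context on the docs change above, here is a minimal usage sketch of switching a model between the two modes (the `Chain`/`Dropout` model is illustrative, not part of this commit):

```julia
using Flux

m = Chain(Dense(784, 64, relu), Dropout(0.5), Dense(64, 10))

testmode!(m)               # lock Dropout (and any norm layers) into inference behaviour
ŷ = m(rand(Float32, 784))  # deterministic forward pass

trainmode!(m)              # the function added here: switch back to training behaviour
```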
@@ -11,7 +11,7 @@ export gradient
 
 export Chain, Dense, Maxout, RNN, LSTM, GRU, Conv, CrossCor, ConvTranspose, MaxPool, MeanPool,
        DepthwiseConv, Dropout, AlphaDropout, LayerNorm, BatchNorm, InstanceNorm, GroupNorm,
-       SkipConnection, params, fmap, cpu, gpu, f32, f64, testmode!
+       SkipConnection, params, fmap, cpu, gpu, f32, f64, testmode!, trainmode!
 
 include("optimise/Optimise.jl")
 using .Optimise
@@ -40,11 +40,14 @@ end
 trainable(m) = functor(m)[1]
 
 """
     testmode!(m, mode = true)
 
 Set a layer or model's test mode (see below).
 Using `:auto` mode will treat any gradient computation as training.
 
+_Note_: if you manually set a model into test mode, you need to manually place
+it back into train mode.
+
 Possible values include:
 - `false` for training
 - `true` for testing
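The note added above matters because a manual call overrides the default `:auto` detection; a sketch of the resulting workflow (the layer choice is illustrative):

```julia
m = BatchNorm(2)    # default :auto mode: gradient computation implies training
testmode!(m)        # manually locked into test mode; gradients no longer switch it
# ... run evaluation ...
trainmode!(m)       # must be restored by hand (testmode!(m, false) is equivalent)
```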
@@ -52,6 +55,22 @@ Possible values include:
 """
 testmode!(m, mode = true) = m
 
+"""
+    trainmode!(m, mode = true)
+
+Set a layer or model's train mode (see below).
+Symmetric to [`testmode!`](@ref) (i.e. `trainmode!(m, mode) == testmode!(m, !mode)`).
+
+_Note_: if you manually set a model into train mode, you need to manually place
+it back into test mode.
+
+Possible values include:
+- `true` for training
+- `false` for testing
+- `:auto` or `nothing` for Flux to detect the mode automatically
+"""
+trainmode!(m, mode = true) = mode isa Bool ? testmode!(m, !mode) : testmode!(m, mode)
+
 params!(p::Params, x::AbstractArray{<:Number}, seen = IdSet()) = push!(p, x)
 
 function params!(p::Params, x, seen = IdSet())
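The one-line definition above flips Boolean modes and forwards anything else, so the following equivalences hold (a sketch; the `Dropout` layer is just an example):

```julia
m = Dropout(0.3)

trainmode!(m)          # ≡ testmode!(m, false): training behaviour
trainmode!(m, false)   # ≡ testmode!(m, true):  test behaviour
trainmode!(m, :auto)   # non-Bool modes pass through to testmode! unchanged
```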
@@ -84,19 +84,19 @@ end
   @test isapprox(y, sigmoid.((x .- m.μ) ./ sqrt.(m.σ² .+ m.ϵ)), atol = 1.0e-7)
 end
 
-let m = testmode!(BatchNorm(2), false), x = reshape(Float32.(1:6), 3, 2, 1)
+let m = trainmode!(BatchNorm(2)), x = reshape(Float32.(1:6), 3, 2, 1)
   y = reshape(permutedims(x, [2, 1, 3]), 2, :)
   y = permutedims(reshape(m(y), 2, 3, 1), [2, 1, 3])
   @test m(x) == y
 end
 
-let m = testmode!(BatchNorm(2), false), x = reshape(Float32.(1:12), 2, 3, 2, 1)
+let m = trainmode!(BatchNorm(2)), x = reshape(Float32.(1:12), 2, 3, 2, 1)
   y = reshape(permutedims(x, [3, 1, 2, 4]), 2, :)
   y = permutedims(reshape(m(y), 2, 2, 3, 1), [2, 3, 1, 4])
   @test m(x) == y
 end
 
-let m = testmode!(BatchNorm(2), false), x = reshape(Float32.(1:24), 2, 2, 3, 2, 1)
+let m = trainmode!(BatchNorm(2)), x = reshape(Float32.(1:24), 2, 2, 3, 2, 1)
   y = reshape(permutedims(x, [4, 1, 2, 3, 5]), 2, :)
   y = permutedims(reshape(m(y), 2, 2, 2, 3, 1), [2, 3, 4, 1, 5])
   @test m(x) == y
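These tests previously forced training behaviour with `testmode!(m, false)`; `trainmode!(m)` is the new, more direct spelling. A small sketch of what train mode means for `BatchNorm` (shapes and tolerance are illustrative, not from the test file):

```julia
using Flux, Statistics

m = trainmode!(BatchNorm(2))            # use batch statistics even outside a gradient call
x = rand(Float32, 4, 2, 5)              # features × channels × batch
y = m(x)
@assert abs(mean(y[:, 1, :])) < 1.0f-5  # each channel is normalised to ≈ zero mean
```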
@@ -164,7 +164,7 @@ end
   @test isapprox(y, sigmoid.((x .- expand_inst(m.μ, affine_shape)) ./ sqrt.(expand_inst(m.σ², affine_shape) .+ m.ϵ)), atol = 1.0e-7)
 end
 
-let m = testmode!(InstanceNorm(2), false), sizes = (2, 4, 1, 2, 3),
+let m = trainmode!(InstanceNorm(2)), sizes = (2, 4, 1, 2, 3),
   x = Float32.(reshape(collect(1:prod(sizes)), sizes))
   y = reshape(permutedims(x, [3, 1, 2, 4, 5]), :, 2, 3)
   y = reshape(m(y), sizes...)
@@ -181,7 +181,7 @@ end
 end
 
 # show that instance norm is equal to batch norm when channel and batch dims are squashed
-let m_inorm = testmode!(InstanceNorm(2), false), m_bnorm = testmode!(BatchNorm(12), false), sizes = (5, 5, 3, 4, 2, 6),
+let m_inorm = trainmode!(InstanceNorm(2)), m_bnorm = trainmode!(BatchNorm(12)), sizes = (5, 5, 3, 4, 2, 6),
   x = reshape(Float32.(collect(1:prod(sizes))), sizes)
   @test m_inorm(x) == reshape(m_bnorm(reshape(x, (sizes[1:end - 2]..., :, 1))), sizes)
 end
@@ -265,7 +265,7 @@ if VERSION >= v"1.1"
   @test isapprox(y, out, atol = 1.0e-7)
 end
 
-let m = testmode!(GroupNorm(2,2), false), sizes = (2, 4, 1, 2, 3),
+let m = trainmode!(GroupNorm(2,2)), sizes = (2, 4, 1, 2, 3),
   x = Float32.(reshape(collect(1:prod(sizes)), sizes))
   y = reshape(permutedims(x, [3, 1, 2, 4, 5]), :, 2, 3)
   y = reshape(m(y), sizes...)
@@ -282,13 +282,13 @@ if VERSION >= v"1.1"
 end
 
 # show that group norm is the same as instance norm when the group size is the same as the number of channels
-let IN = testmode!(InstanceNorm(4), false), GN = testmode!(GroupNorm(4,4), false), sizes = (2,2,3,4,5),
+let IN = trainmode!(InstanceNorm(4)), GN = trainmode!(GroupNorm(4,4)), sizes = (2,2,3,4,5),
   x = Float32.(reshape(collect(1:prod(sizes)), sizes))
   @test IN(x) ≈ GN(x)
 end
 
 # show that group norm is the same as batch norm for a group of size 1 and batch of size 1
-let BN = testmode!(BatchNorm(4), false), GN = testmode!(GroupNorm(4,4), false), sizes = (2,2,3,4,1),
+let BN = trainmode!(BatchNorm(4)), GN = trainmode!(GroupNorm(4,4)), sizes = (2,2,3,4,1),
   x = Float32.(reshape(collect(1:prod(sizes)), sizes))
   @test BN(x) ≈ GN(x)
 end
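Finally, in the common case none of these manual calls are needed: in `:auto` mode Flux treats any gradient computation as training. A sketch using the training API of this era (the model, loss, and data are placeholders):

```julia
using Flux

m = Chain(Dense(2, 2), BatchNorm(2))
loss(x, y) = Flux.mse(m(x), y)
data = [(rand(Float32, 2, 8), rand(Float32, 2, 8))]

Flux.train!(loss, params(m), data, Descent(0.1))  # inside gradients ⇒ train behaviour
ŷ = m(rand(Float32, 2, 8))                        # plain call ⇒ test behaviour
```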