1044: Add testmode! back for normalization layers r=CarloLucibello a=darsnack

Fixed #909 

I added `testmode!(m, mode)` back to Flux as per v0.9. Now the `mode` can be `false`, `true`, or `:auto`/`nothing` with the default being `:auto` for newly constructed layers. In `:auto` mode, the `istraining()` functions added in v0.10 are used to determine whether we are evaluating within an AD trace or not.

Also plan on adding a doc section in an additional commit.

Co-authored-by: Kyle Daruwalla <daruwalla@wisc.edu>
This commit is contained in:
bors[bot] 2020-03-01 19:14:07 +00:00 committed by GitHub
commit 069d228693
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 114 additions and 34 deletions

View File

@ -54,6 +54,15 @@ LayerNorm
GroupNorm GroupNorm
``` ```
### Testmode
Many normalisation layers behave differently under training and inference (testing). By default, Flux will automatically determine when a layer evaluation is part of training or inference. Still, depending on your use case, it may be helpful to manually specify when these layers should be treated as being trained or not. For this, Flux provides `testmode!`. When called on a model (e.g. a layer or chain of layers), this function will place the model into the mode specified.
```@docs
testmode!
trainmode!
```
## Cost Functions ## Cost Functions
```@docs ```@docs
Flux.mse Flux.mse

View File

@ -12,7 +12,7 @@ export gradient
export Chain, Dense, Maxout, RNN, LSTM, GRU, Conv, CrossCor, ConvTranspose, MaxPool, MeanPool, export Chain, Dense, Maxout, RNN, LSTM, GRU, Conv, CrossCor, ConvTranspose, MaxPool, MeanPool,
DepthwiseConv, Dropout, AlphaDropout, LayerNorm, BatchNorm, InstanceNorm, GroupNorm, DepthwiseConv, Dropout, AlphaDropout, LayerNorm, BatchNorm, InstanceNorm, GroupNorm,
SkipConnection, params, fmap, cpu, gpu, f32, f64 SkipConnection, params, fmap, cpu, gpu, f32, f64, testmode!, trainmode!
include("optimise/Optimise.jl") include("optimise/Optimise.jl")
using .Optimise using .Optimise

View File

@ -39,6 +39,38 @@ end
trainable(m) = functor(m)[1] trainable(m) = functor(m)[1]
"""
testmode!(m, mode = true)
Set a layer or model's test mode (see below).
Using `:auto` mode will treat any gradient computation as training.
_Note_: if you manually set a model into test mode, you need to manually place
it back into train mode during training phase.
Possible values include:
- `false` for training
- `true` for testing
- `:auto` or `nothing` for Flux to detect the mode automatically
"""
testmode!(m, mode = true) = m
"""
trainmode!(m, mode = true)
Set a layer of model's train mode (see below).
Symmetric to [`testmode!`](@ref) (i.e. `trainmode!(m, mode) == testmode!(m, !mode)).
_Note_: if you manually set a model into train mode, you need to manually place
it into test mode during testing phase.
Possible values include:
- `true` for training
- `false` for testing
- `:auto` or `nothing` for Flux to detect the mode automatically
"""
trainmode!(m, mode = true) = mode isa Bool ? testmode!(m, !mode) : testmode!(m, mode)
params!(p::Params, x::AbstractArray{<:Number}, seen = IdSet()) = push!(p, x) params!(p::Params, x::AbstractArray{<:Number}, seen = IdSet()) = push!(p, x)
function params!(p::Params, x, seen = IdSet()) function params!(p::Params, x, seen = IdSet())

View File

@ -33,6 +33,8 @@ applychain(fs::Tuple, x) = applychain(tail(fs), first(fs)(x))
Base.getindex(c::Chain, i::AbstractArray) = Chain(c.layers[i]...) Base.getindex(c::Chain, i::AbstractArray) = Chain(c.layers[i]...)
testmode!(m::Chain, mode = true) = (map(x -> testmode!(x, mode), m.layers); m)
function Base.show(io::IO, c::Chain) function Base.show(io::IO, c::Chain)
print(io, "Chain(") print(io, "Chain(")
join(io, c.layers, ", ") join(io, c.layers, ", ")

View File

@ -2,6 +2,8 @@ istraining() = false
@adjoint istraining() = true, _ -> nothing @adjoint istraining() = true, _ -> nothing
_isactive(m) = isnothing(m.active) ? istraining() : m.active
_dropout_shape(s, ::Colon) = size(s) _dropout_shape(s, ::Colon) = size(s)
_dropout_shape(s, dims) = tuple((i dims ? 1 : si for (i, si) enumerate(size(s)))...) _dropout_shape(s, dims) = tuple((i dims ? 1 : si for (i, si) enumerate(size(s)))...)
@ -29,18 +31,27 @@ end
Dropout(p, dims = :) Dropout(p, dims = :)
A Dropout layer. In the forward pass, applies the [`dropout`](@ref) function on the input. A Dropout layer. In the forward pass, applies the [`dropout`](@ref) function on the input.
Does nothing to the input once [`testmode!`](@ref) is false.
""" """
mutable struct Dropout{F,D} mutable struct Dropout{F,D}
p::F p::F
dims::D dims::D
active::Union{Bool, Nothing}
end end
function Dropout(p; dims = :) function Dropout(p; dims = :)
@assert 0 p 1 @assert 0 p 1
Dropout{typeof(p),typeof(dims)}(p, dims) Dropout{typeof(p),typeof(dims)}(p, dims, nothing)
end end
(a::Dropout)(x) = dropout(x, a.p; dims = a.dims) function (a::Dropout)(x)
_isactive(a) || return x
return dropout(x, a.p; dims = a.dims)
end
testmode!(m::Dropout, mode = true) =
(m.active = (isnothing(mode) || mode == :auto) ? nothing : !mode; m)
function Base.show(io::IO, d::Dropout) function Base.show(io::IO, d::Dropout)
print(io, "Dropout(", d.p) print(io, "Dropout(", d.p)
@ -54,17 +65,20 @@ end
A dropout layer. It is used in Self-Normalizing Neural Networks. A dropout layer. It is used in Self-Normalizing Neural Networks.
(https://papers.nips.cc/paper/6698-self-normalizing-neural-networks.pdf) (https://papers.nips.cc/paper/6698-self-normalizing-neural-networks.pdf)
The AlphaDropout layer ensures that mean and variance of activations remains the same as before. The AlphaDropout layer ensures that mean and variance of activations remains the same as before.
Does nothing to the input once [`testmode!`](@ref) is false.
""" """
mutable struct AlphaDropout{F} mutable struct AlphaDropout{F}
p::F p::F
function AlphaDropout(p) active::Union{Bool, Nothing}
function AlphaDropout(p, active = nothing)
@assert 0 p 1 @assert 0 p 1
new{typeof(p)}(p) new{typeof(p)}(p, active)
end end
end end
function (a::AlphaDropout)(x) function (a::AlphaDropout)(x)
istraining() || return x _isactive(a) || return x
λ = eltype(x)(1.0507009873554804934193349852946) λ = eltype(x)(1.0507009873554804934193349852946)
α = eltype(x)(1.6732632423543772848170429916717) α = eltype(x)(1.6732632423543772848170429916717)
α1 = eltype(x)(-λ*α) α1 = eltype(x)(-λ*α)
@ -76,6 +90,9 @@ function (a::AlphaDropout)(x)
return x return x
end end
testmode!(m::AlphaDropout, mode = true) =
(m.active = (isnothing(mode) || mode == :auto) ? nothing : !mode; m)
""" """
LayerNorm(h::Integer) LayerNorm(h::Integer)
@ -114,6 +131,8 @@ it's the usual channel dimension.)
shifts them to have a new mean and variance (corresponding to the learnable, shifts them to have a new mean and variance (corresponding to the learnable,
per-channel `bias` and `scale` parameters). per-channel `bias` and `scale` parameters).
Use [`testmode!`](@ref) during inference.
See [Batch Normalization: Accelerating Deep Network Training by Reducing See [Batch Normalization: Accelerating Deep Network Training by Reducing
Internal Covariate Shift](https://arxiv.org/pdf/1502.03167.pdf). Internal Covariate Shift](https://arxiv.org/pdf/1502.03167.pdf).
@ -135,12 +154,13 @@ mutable struct BatchNorm{F,V,W,N}
σ²::W # moving std σ²::W # moving std
ϵ::N ϵ::N
momentum::N momentum::N
active::Union{Bool, Nothing}
end end
BatchNorm(chs::Integer, λ = identity; BatchNorm(chs::Integer, λ = identity;
initβ = (i) -> zeros(Float32, i), initγ = (i) -> ones(Float32, i), ϵ = 1f-5, momentum = 0.1f0) = initβ = (i) -> zeros(Float32, i), initγ = (i) -> ones(Float32, i), ϵ = 1f-5, momentum = 0.1f0) =
BatchNorm(λ, initβ(chs), initγ(chs), BatchNorm(λ, initβ(chs), initγ(chs),
zeros(chs), ones(chs), ϵ, momentum) zeros(chs), ones(chs), ϵ, momentum, nothing)
trainable(bn::BatchNorm) = (bn.β, bn.γ) trainable(bn::BatchNorm) = (bn.β, bn.γ)
@ -153,7 +173,7 @@ function (BN::BatchNorm)(x)
m = div(prod(size(x)), channels) m = div(prod(size(x)), channels)
γ = reshape(BN.γ, affine_shape...) γ = reshape(BN.γ, affine_shape...)
β = reshape(BN.β, affine_shape...) β = reshape(BN.β, affine_shape...)
if !istraining() if !_isactive(BN)
μ = reshape(BN.μ, affine_shape...) μ = reshape(BN.μ, affine_shape...)
σ² = reshape(BN.σ², affine_shape...) σ² = reshape(BN.σ², affine_shape...)
ϵ = BN.ϵ ϵ = BN.ϵ
@ -178,6 +198,9 @@ end
@functor BatchNorm @functor BatchNorm
testmode!(m::BatchNorm, mode = true) =
(m.active = (isnothing(mode) || mode == :auto) ? nothing : !mode; m)
function Base.show(io::IO, l::BatchNorm) function Base.show(io::IO, l::BatchNorm)
print(io, "BatchNorm($(join(size(l.β), ", "))") print(io, "BatchNorm($(join(size(l.β), ", "))")
(l.λ == identity) || print(io, ", λ = $(l.λ)") (l.λ == identity) || print(io, ", λ = $(l.λ)")
@ -201,6 +224,8 @@ it's the usual channel dimension.)
shifts them to have a new mean and variance (corresponding to the learnable, shifts them to have a new mean and variance (corresponding to the learnable,
per-channel `bias` and `scale` parameters). per-channel `bias` and `scale` parameters).
Use [`testmode!`](@ref) during inference.
See [Instance Normalization: The Missing Ingredient for Fast Stylization](https://arxiv.org/abs/1607.08022). See [Instance Normalization: The Missing Ingredient for Fast Stylization](https://arxiv.org/abs/1607.08022).
Example: Example:
@ -223,12 +248,13 @@ mutable struct InstanceNorm{F,V,W,N}
σ²::W # moving std σ²::W # moving std
ϵ::N ϵ::N
momentum::N momentum::N
active::Union{Bool, Nothing}
end end
InstanceNorm(chs::Integer, λ = identity; InstanceNorm(chs::Integer, λ = identity;
initβ = (i) -> zeros(Float32, i), initγ = (i) -> ones(Float32, i), ϵ = 1f-5, momentum = 0.1f0) = initβ = (i) -> zeros(Float32, i), initγ = (i) -> ones(Float32, i), ϵ = 1f-5, momentum = 0.1f0) =
InstanceNorm(λ, initβ(chs), initγ(chs), InstanceNorm(λ, initβ(chs), initγ(chs),
zeros(chs), ones(chs), ϵ, momentum) zeros(chs), ones(chs), ϵ, momentum, nothing)
trainable(in::InstanceNorm) = (in.β, in.γ) trainable(in::InstanceNorm) = (in.β, in.γ)
@ -245,7 +271,7 @@ function (in::InstanceNorm)(x)
m = div(prod(size(x)), c*bs) m = div(prod(size(x)), c*bs)
γ, β = expand_inst(in.γ, affine_shape), expand_inst(in.β, affine_shape) γ, β = expand_inst(in.γ, affine_shape), expand_inst(in.β, affine_shape)
if !istraining() if !_isactive(in)
μ = expand_inst(in.μ, affine_shape) μ = expand_inst(in.μ, affine_shape)
σ² = expand_inst(in.σ², affine_shape) σ² = expand_inst(in.σ², affine_shape)
ϵ = in.ϵ ϵ = in.ϵ
@ -271,6 +297,9 @@ end
@functor InstanceNorm @functor InstanceNorm
testmode!(m::InstanceNorm, mode = true) =
(m.active = (isnothing(mode) || mode == :auto) ? nothing : !mode; m)
function Base.show(io::IO, l::InstanceNorm) function Base.show(io::IO, l::InstanceNorm)
print(io, "InstanceNorm($(join(size(l.β), ", "))") print(io, "InstanceNorm($(join(size(l.β), ", "))")
(l.λ == identity) || print(io, ", λ = $(l.λ)") (l.λ == identity) || print(io, ", λ = $(l.λ)")
@ -291,6 +320,8 @@ For an array of N dimensions, the (N-1)th index is the channel dimension.
``G`` is the number of groups along which the statistics would be computed. ``G`` is the number of groups along which the statistics would be computed.
The number of channels must be an integer multiple of the number of groups. The number of channels must be an integer multiple of the number of groups.
Use [`testmode!`](@ref) during inference.
Example: Example:
``` ```
m = Chain(Conv((3,3), 1=>32, leakyrelu;pad = 1), m = Chain(Conv((3,3), 1=>32, leakyrelu;pad = 1),
@ -308,12 +339,13 @@ mutable struct GroupNorm{F,V,W,N,T}
σ²::W # moving std σ²::W # moving std
ϵ::N ϵ::N
momentum::N momentum::N
active::Union{Bool, Nothing}
end end
GroupNorm(chs::Integer, G::Integer, λ = identity; GroupNorm(chs::Integer, G::Integer, λ = identity;
initβ = (i) -> zeros(Float32, i), initγ = (i) -> ones(Float32, i), ϵ = 1f-5, momentum = 0.1f0) = initβ = (i) -> zeros(Float32, i), initγ = (i) -> ones(Float32, i), ϵ = 1f-5, momentum = 0.1f0) =
GroupNorm(G, λ, initβ(chs), initγ(chs), GroupNorm(G, λ, initβ(chs), initγ(chs),
zeros(G,1), ones(G,1), ϵ, momentum) zeros(G,1), ones(G,1), ϵ, momentum, nothing)
trainable(gn::GroupNorm) = (gn.β, gn.γ) trainable(gn::GroupNorm) = (gn.β, gn.γ)
@ -337,7 +369,7 @@ function(gn::GroupNorm)(x)
β = reshape(gn.β, affine_shape...) β = reshape(gn.β, affine_shape...)
y = reshape(x,((size(x))[1:end-2]...,channels_per_group,groups,batches)) y = reshape(x,((size(x))[1:end-2]...,channels_per_group,groups,batches))
if !istraining() if !_isactive(gn)
og_shape = size(x) og_shape = size(x)
μ = reshape(gn.μ, μ_affine_shape...) # Shape : (1,1,...C/G,G,1) μ = reshape(gn.μ, μ_affine_shape...) # Shape : (1,1,...C/G,G,1)
σ² = reshape(gn.σ², μ_affine_shape...) # Shape : (1,1,...C/G,G,1) σ² = reshape(gn.σ², μ_affine_shape...) # Shape : (1,1,...C/G,G,1)
@ -368,6 +400,9 @@ end
@functor GroupNorm @functor GroupNorm
testmode!(m::GroupNorm, mode = true) =
(m.active = (isnothing(mode) || mode == :auto) ? nothing : !mode; m)
function Base.show(io::IO, l::GroupNorm) function Base.show(io::IO, l::GroupNorm)
print(io, "GroupNorm($(join(size(l.β), ", "))") print(io, "GroupNorm($(join(size(l.β), ", "))")
(l.λ == identity) || print(io, ", λ = $(l.λ)") (l.λ == identity) || print(io, ", λ = $(l.λ)")

View File

@ -1,30 +1,32 @@
using Flux, Test, Statistics using Flux, Test, Statistics
using Zygote: pullback using Zygote: pullback
trainmode(f, x...) = pullback(f, x...)[1] evalwgrad(f, x...) = pullback(f, x...)[1]
trainmode(f) = (x...) -> trainmode(f, x...)
@testset "Dropout" begin @testset "Dropout" begin
x = [1.,2.,3.] x = [1.,2.,3.]
@test x == Dropout(0.1)(x) @test x == Dropout(0.1)(x)
@test x == trainmode(Dropout(0), x) @test x == evalwgrad(Dropout(0), x)
@test zero(x) == trainmode(Dropout(1), x) @test zero(x) == evalwgrad(Dropout(1), x)
x = rand(100) x = rand(100)
m = Dropout(0.9) m = Dropout(0.9)
y = trainmode(m, x) y = evalwgrad(m, x)
@test count(a->a==0, y) > 50 @test count(a->a==0, y) > 50
y = m(x) testmode!(m, true)
y = evalwgrad(m, x) # should override istraining
@test count(a->a==0, y) == 0 @test count(a->a==0, y) == 0
y = trainmode(m, x) testmode!(m, false)
y = evalwgrad(m, x)
@test count(a->a==0, y) > 50 @test count(a->a==0, y) > 50
x = rand(Float32, 100) x = rand(Float32, 100)
m = Chain(Dense(100,100), m = Chain(Dense(100,100),
Dropout(0.9)) Dropout(0.9))
y = trainmode(m, x) y = evalwgrad(m, x)
@test count(a->a == 0, y) > 50 @test count(a->a == 0, y) > 50
y = m(x) testmode!(m, true)
y = evalwgrad(m, x) # should override istraining
@test count(a->a == 0, y) == 0 @test count(a->a == 0, y) == 0
x = rand(100, 50) x = rand(100, 50)
@ -49,7 +51,7 @@ end
# initial m.σ is 1 # initial m.σ is 1
# initial m.μ is 0 # initial m.μ is 0
y = trainmode(m, x) y = evalwgrad(m, x)
@test isapprox(y, [-1.22474 0 1.22474; -1.22474 0 1.22474], atol = 1.0e-5) @test isapprox(y, [-1.22474 0 1.22474; -1.22474 0 1.22474], atol = 1.0e-5)
# julia> x # julia> x
# 2×3 Array{Float64,2}: # 2×3 Array{Float64,2}:
@ -82,19 +84,19 @@ end
@test isapprox(y, sigmoid.((x .- m.μ) ./ sqrt.(m.σ² .+ m.ϵ)), atol = 1.0e-7) @test isapprox(y, sigmoid.((x .- m.μ) ./ sqrt.(m.σ² .+ m.ϵ)), atol = 1.0e-7)
end end
let m = trainmode(BatchNorm(2)), x = reshape(Float32.(1:6), 3, 2, 1) let m = trainmode!(BatchNorm(2)), x = reshape(Float32.(1:6), 3, 2, 1)
y = reshape(permutedims(x, [2, 1, 3]), 2, :) y = reshape(permutedims(x, [2, 1, 3]), 2, :)
y = permutedims(reshape(m(y), 2, 3, 1), [2, 1, 3]) y = permutedims(reshape(m(y), 2, 3, 1), [2, 1, 3])
@test m(x) == y @test m(x) == y
end end
let m = trainmode(BatchNorm(2)), x = reshape(Float32.(1:12), 2, 3, 2, 1) let m = trainmode!(BatchNorm(2)), x = reshape(Float32.(1:12), 2, 3, 2, 1)
y = reshape(permutedims(x, [3, 1, 2, 4]), 2, :) y = reshape(permutedims(x, [3, 1, 2, 4]), 2, :)
y = permutedims(reshape(m(y), 2, 2, 3, 1), [2, 3, 1, 4]) y = permutedims(reshape(m(y), 2, 2, 3, 1), [2, 3, 1, 4])
@test m(x) == y @test m(x) == y
end end
let m = trainmode(BatchNorm(2)), x = reshape(Float32.(1:24), 2, 2, 3, 2, 1) let m = trainmode!(BatchNorm(2)), x = reshape(Float32.(1:24), 2, 2, 3, 2, 1)
y = reshape(permutedims(x, [4, 1, 2, 3, 5]), 2, :) y = reshape(permutedims(x, [4, 1, 2, 3, 5]), 2, :)
y = permutedims(reshape(m(y), 2, 2, 2, 3, 1), [2, 3, 4, 1, 5]) y = permutedims(reshape(m(y), 2, 2, 2, 3, 1), [2, 3, 4, 1, 5])
@test m(x) == y @test m(x) == y
@ -117,7 +119,7 @@ end
x = Float64.(x) x = Float64.(x)
@test m.β == [0, 0] # initβ(2) @test m.β == [0, 0] # initβ(2)
@test m.γ == [1, 1] # initγ(2) @test m.γ == [1, 1] # initγ(2)
y = trainmode(m, x) y = evalwgrad(m, x)
#julia> x #julia> x
#[:, :, 1] = #[:, :, 1] =
@ -162,7 +164,7 @@ end
@test isapprox(y, sigmoid.((x .- expand_inst(m.μ, affine_shape)) ./ sqrt.(expand_inst(m.σ², affine_shape) .+ m.ϵ)), atol = 1.0e-7) @test isapprox(y, sigmoid.((x .- expand_inst(m.μ, affine_shape)) ./ sqrt.(expand_inst(m.σ², affine_shape) .+ m.ϵ)), atol = 1.0e-7)
end end
let m = trainmode(InstanceNorm(2)), sizes = (2, 4, 1, 2, 3), let m = trainmode!(InstanceNorm(2)), sizes = (2, 4, 1, 2, 3),
x = Float32.(reshape(collect(1:prod(sizes)), sizes)) x = Float32.(reshape(collect(1:prod(sizes)), sizes))
y = reshape(permutedims(x, [3, 1, 2, 4, 5]), :, 2, 3) y = reshape(permutedims(x, [3, 1, 2, 4, 5]), :, 2, 3)
y = reshape(m(y), sizes...) y = reshape(m(y), sizes...)
@ -172,14 +174,14 @@ end
# check that μ, σ², and the output are the correct size for higher rank tensors # check that μ, σ², and the output are the correct size for higher rank tensors
let m = InstanceNorm(2), sizes = (5, 5, 3, 4, 2, 6), let m = InstanceNorm(2), sizes = (5, 5, 3, 4, 2, 6),
x = reshape(Float32.(collect(1:prod(sizes))), sizes) x = reshape(Float32.(collect(1:prod(sizes))), sizes)
y = trainmode(m, x) y = evalwgrad(m, x)
@test size(m.μ) == (sizes[end - 1], ) @test size(m.μ) == (sizes[end - 1], )
@test size(m.σ²) == (sizes[end - 1], ) @test size(m.σ²) == (sizes[end - 1], )
@test size(y) == sizes @test size(y) == sizes
end end
# show that instance norm is equal to batch norm when channel and batch dims are squashed # show that instance norm is equal to batch norm when channel and batch dims are squashed
let m_inorm = trainmode(InstanceNorm(2)), m_bnorm = trainmode(BatchNorm(12)), sizes = (5, 5, 3, 4, 2, 6), let m_inorm = trainmode!(InstanceNorm(2)), m_bnorm = trainmode!(BatchNorm(12)), sizes = (5, 5, 3, 4, 2, 6),
x = reshape(Float32.(collect(1:prod(sizes))), sizes) x = reshape(Float32.(collect(1:prod(sizes))), sizes)
@test m_inorm(x) == reshape(m_bnorm(reshape(x, (sizes[1:end - 2]..., :, 1))), sizes) @test m_inorm(x) == reshape(m_bnorm(reshape(x, (sizes[1:end - 2]..., :, 1))), sizes)
end end
@ -204,7 +206,7 @@ if VERSION >= v"1.1"
@test m.β == [0, 0, 0, 0] # initβ(32) @test m.β == [0, 0, 0, 0] # initβ(32)
@test m.γ == [1, 1, 1, 1] # initγ(32) @test m.γ == [1, 1, 1, 1] # initγ(32)
y = trainmode(m, x) y = evalwgrad(m, x)
#julia> x #julia> x
#[:, :, 1] = #[:, :, 1] =
@ -263,7 +265,7 @@ if VERSION >= v"1.1"
@test isapprox(y, out, atol = 1.0e-7) @test isapprox(y, out, atol = 1.0e-7)
end end
let m = trainmode(GroupNorm(2,2)), sizes = (2, 4, 1, 2, 3), let m = trainmode!(GroupNorm(2,2)), sizes = (2, 4, 1, 2, 3),
x = Float32.(reshape(collect(1:prod(sizes)), sizes)) x = Float32.(reshape(collect(1:prod(sizes)), sizes))
y = reshape(permutedims(x, [3, 1, 2, 4, 5]), :, 2, 3) y = reshape(permutedims(x, [3, 1, 2, 4, 5]), :, 2, 3)
y = reshape(m(y), sizes...) y = reshape(m(y), sizes...)
@ -273,20 +275,20 @@ if VERSION >= v"1.1"
# check that μ, σ², and the output are the correct size for higher rank tensors # check that μ, σ², and the output are the correct size for higher rank tensors
let m = GroupNorm(4,2), sizes = (5, 5, 3, 4, 4, 6), let m = GroupNorm(4,2), sizes = (5, 5, 3, 4, 4, 6),
x = Float32.(reshape(collect(1:prod(sizes)), sizes)) x = Float32.(reshape(collect(1:prod(sizes)), sizes))
y = trainmode(m, x) y = evalwgrad(m, x)
@test size(m.μ) == (m.G,1) @test size(m.μ) == (m.G,1)
@test size(m.σ²) == (m.G,1) @test size(m.σ²) == (m.G,1)
@test size(y) == sizes @test size(y) == sizes
end end
# show that group norm is the same as instance norm when the group size is the same as the number of channels # show that group norm is the same as instance norm when the group size is the same as the number of channels
let IN = trainmode(InstanceNorm(4)), GN = trainmode(GroupNorm(4,4)), sizes = (2,2,3,4,5), let IN = trainmode!(InstanceNorm(4)), GN = trainmode!(GroupNorm(4,4)), sizes = (2,2,3,4,5),
x = Float32.(reshape(collect(1:prod(sizes)), sizes)) x = Float32.(reshape(collect(1:prod(sizes)), sizes))
@test IN(x) GN(x) @test IN(x) GN(x)
end end
# show that group norm is the same as batch norm for a group of size 1 and batch of size 1 # show that group norm is the same as batch norm for a group of size 1 and batch of size 1
let BN = trainmode(BatchNorm(4)), GN = trainmode(GroupNorm(4,4)), sizes = (2,2,3,4,1), let BN = trainmode!(BatchNorm(4)), GN = trainmode!(GroupNorm(4,4)), sizes = (2,2,3,4,1),
x = Float32.(reshape(collect(1:prod(sizes)), sizes)) x = Float32.(reshape(collect(1:prod(sizes)), sizes))
@test BN(x) GN(x) @test BN(x) GN(x)
end end