Merge #1044
1044: Add testmode! back for normalization layers r=CarloLucibello a=darsnack

Fixed #909. I added `testmode!(m, mode)` back to Flux as per v0.9. Now `mode` can be `false`, `true`, or `:auto`/`nothing`, with the default being `:auto` for newly constructed layers. In `:auto` mode, the `istraining()` functions added in v0.10 are used to determine whether we are evaluating within an AD trace or not. I also plan to add a docs section in an additional commit.

Co-authored-by: Kyle Daruwalla <daruwalla@wisc.edu>
This commit is contained in commit 069d228693.
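Before the diff itself, here is a minimal usage sketch of the API described in the commit message above (the dropout probability is arbitrary; the calls follow the semantics added in this PR):

```julia
using Flux

d = Dropout(0.5)     # newly constructed layers default to :auto (d.active == nothing)
testmode!(d)         # force inference behaviour (mode defaults to true)
testmode!(d, false)  # force training behaviour
testmode!(d, :auto)  # fall back to automatic detection via istraining()
trainmode!(d)        # equivalent to testmode!(d, false)
```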
@@ -54,6 +54,15 @@ LayerNorm
 GroupNorm
 ```
 
+### Testmode
+
+Many normalisation layers behave differently under training and inference (testing). By default, Flux will automatically determine when a layer evaluation is part of training or inference. Still, depending on your use case, it may be helpful to manually specify when these layers should be treated as being trained or not. For this, Flux provides `testmode!`. When called on a model (e.g. a layer or chain of layers), this function will place the model into the mode specified.
+
+```@docs
+testmode!
+trainmode!
+```
+
 ## Cost Functions
 ```@docs
 Flux.mse
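As an illustration of the documentation added above, `testmode!` called on a `Chain` propagates to every layer; a hedged sketch (the model shape and input sizes are invented for the example):

```julia
using Flux

m = Chain(Dense(10, 5), Dropout(0.4), Dense(5, 2), BatchNorm(2))
x = rand(Float32, 10, 8)

testmode!(m)   # every layer in the chain is placed into test mode
ŷ = m(x)       # Dropout is skipped; BatchNorm uses its running statistics
trainmode!(m)  # restore training behaviour
```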
@@ -12,7 +12,7 @@ export gradient
 
 export Chain, Dense, Maxout, RNN, LSTM, GRU, Conv, CrossCor, ConvTranspose, MaxPool, MeanPool,
        DepthwiseConv, Dropout, AlphaDropout, LayerNorm, BatchNorm, InstanceNorm, GroupNorm,
-       SkipConnection, params, fmap, cpu, gpu, f32, f64
+       SkipConnection, params, fmap, cpu, gpu, f32, f64, testmode!, trainmode!
 
 include("optimise/Optimise.jl")
 using .Optimise
@@ -39,6 +39,38 @@ end
 
 trainable(m) = functor(m)[1]
 
+"""
+    testmode!(m, mode = true)
+
+Set a layer or model's test mode (see below).
+Using `:auto` mode will treat any gradient computation as training.
+
+_Note_: if you manually set a model into test mode, you need to manually place
+it back into train mode during the training phase.
+
+Possible values include:
+- `false` for training
+- `true` for testing
+- `:auto` or `nothing` for Flux to detect the mode automatically
+"""
+testmode!(m, mode = true) = m
+
+"""
+    trainmode!(m, mode = true)
+
+Set a layer or model's train mode (see below).
+Symmetric to [`testmode!`](@ref) (i.e. `trainmode!(m, mode) == testmode!(m, !mode)`).
+
+_Note_: if you manually set a model into train mode, you need to manually place
+it into test mode during the testing phase.
+
+Possible values include:
+- `true` for training
+- `false` for testing
+- `:auto` or `nothing` for Flux to detect the mode automatically
+"""
+trainmode!(m, mode = true) = mode isa Bool ? testmode!(m, !mode) : testmode!(m, mode)
+
 params!(p::Params, x::AbstractArray{<:Number}, seen = IdSet()) = push!(p, x)
 
 function params!(p::Params, x, seen = IdSet())
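To make the symmetry between the two functions concrete, a small sketch using the `Dropout` layer modified later in this diff (the `active` field is the internal state the mode maps onto):

```julia
using Flux

d = Dropout(0.3)

trainmode!(d, true)   # same as testmode!(d, false); d.active == true
testmode!(d, true)    # d.active == false
trainmode!(d, :auto)  # non-Bool modes pass straight through; d.active == nothing
```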
@@ -33,6 +33,8 @@ applychain(fs::Tuple, x) = applychain(tail(fs), first(fs)(x))
 
 Base.getindex(c::Chain, i::AbstractArray) = Chain(c.layers[i]...)
 
+testmode!(m::Chain, mode = true) = (map(x -> testmode!(x, mode), m.layers); m)
+
 function Base.show(io::IO, c::Chain)
   print(io, "Chain(")
   join(io, c.layers, ", ")
@@ -2,6 +2,8 @@ istraining() = false
 
 @adjoint istraining() = true, _ -> nothing
 
+_isactive(m) = isnothing(m.active) ? istraining() : m.active
+
 _dropout_shape(s, ::Colon) = size(s)
 _dropout_shape(s, dims) = tuple((i ∉ dims ? 1 : si for (i, si) ∈ enumerate(size(s)))...)
 
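`_isactive` is an internal helper; a rough sketch of how it interacts with the `istraining` adjoint (not part of the public API, shown only to clarify the `:auto` behaviour):

```julia
using Flux, Zygote

Flux.istraining()                            # false in ordinary code
Zygote.pullback(() -> Flux.istraining())[1]  # true inside an AD trace

d = Dropout(0.5)    # d.active == nothing, i.e. :auto
Flux._isactive(d)   # falls back to istraining(): false here
testmode!(d)        # d.active == false, overriding the automatic detection
Flux._isactive(d)   # false regardless of AD context
```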
@@ -29,18 +31,27 @@ end
     Dropout(p, dims = :)
 
 A Dropout layer. In the forward pass, applies the [`dropout`](@ref) function on the input.
+
+Does nothing to the input once [`testmode!`](@ref) is false.
 """
 mutable struct Dropout{F,D}
   p::F
   dims::D
+  active::Union{Bool, Nothing}
 end
 
 function Dropout(p; dims = :)
   @assert 0 ≤ p ≤ 1
-  Dropout{typeof(p),typeof(dims)}(p, dims)
+  Dropout{typeof(p),typeof(dims)}(p, dims, nothing)
 end
 
-(a::Dropout)(x) = dropout(x, a.p; dims = a.dims)
+function (a::Dropout)(x)
+  _isactive(a) || return x
+  return dropout(x, a.p; dims = a.dims)
+end
+
+testmode!(m::Dropout, mode = true) =
+  (m.active = (isnothing(mode) || mode == :auto) ? nothing : !mode; m)
+
 function Base.show(io::IO, d::Dropout)
   print(io, "Dropout(", d.p)
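A sketch of the resulting `Dropout` behaviour (array size and probability chosen arbitrarily for illustration):

```julia
using Flux
using Zygote: pullback

x = ones(Float32, 1000)
d = Dropout(0.9)

count(iszero, d(x))               # ≈ 0: in :auto mode a plain call counts as inference
count(iszero, pullback(d, x)[1])  # ≈ 900: inside pullback, istraining() is true

trainmode!(d)
count(iszero, d(x))               # ≈ 900: forced into training even outside AD
```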
@@ -54,17 +65,20 @@ end
 A dropout layer. It is used in Self-Normalizing Neural Networks.
 (https://papers.nips.cc/paper/6698-self-normalizing-neural-networks.pdf)
 The AlphaDropout layer ensures that mean and variance of activations remains the same as before.
+
+Does nothing to the input once [`testmode!`](@ref) is false.
 """
 mutable struct AlphaDropout{F}
   p::F
-  function AlphaDropout(p)
+  active::Union{Bool, Nothing}
+  function AlphaDropout(p, active = nothing)
     @assert 0 ≤ p ≤ 1
-    new{typeof(p)}(p)
+    new{typeof(p)}(p, active)
   end
 end
 
 function (a::AlphaDropout)(x)
-  istraining() || return x
+  _isactive(a) || return x
   λ = eltype(x)(1.0507009873554804934193349852946)
   α = eltype(x)(1.6732632423543772848170429916717)
   α1 = eltype(x)(-λ*α)
@@ -76,6 +90,9 @@ function (a::AlphaDropout)(x)
   return x
 end
 
+testmode!(m::AlphaDropout, mode = true) =
+  (m.active = (isnothing(mode) || mode == :auto) ? nothing : !mode; m)
+
 """
     LayerNorm(h::Integer)
 
@@ -114,6 +131,8 @@ it's the usual channel dimension.)
 shifts them to have a new mean and variance (corresponding to the learnable,
 per-channel `bias` and `scale` parameters).
 
+Use [`testmode!`](@ref) during inference.
+
 See [Batch Normalization: Accelerating Deep Network Training by Reducing
 Internal Covariate Shift](https://arxiv.org/pdf/1502.03167.pdf).
 
@@ -135,12 +154,13 @@ mutable struct BatchNorm{F,V,W,N}
   σ²::W  # moving std
   ϵ::N
   momentum::N
+  active::Union{Bool, Nothing}
 end
 
 BatchNorm(chs::Integer, λ = identity;
           initβ = (i) -> zeros(Float32, i), initγ = (i) -> ones(Float32, i), ϵ = 1f-5, momentum = 0.1f0) =
   BatchNorm(λ, initβ(chs), initγ(chs),
-            zeros(chs), ones(chs), ϵ, momentum)
+            zeros(chs), ones(chs), ϵ, momentum, nothing)
 
 trainable(bn::BatchNorm) = (bn.β, bn.γ)
 
@@ -153,7 +173,7 @@ function (BN::BatchNorm)(x)
   m = div(prod(size(x)), channels)
   γ = reshape(BN.γ, affine_shape...)
   β = reshape(BN.β, affine_shape...)
-  if !istraining()
+  if !_isactive(BN)
     μ = reshape(BN.μ, affine_shape...)
     σ² = reshape(BN.σ², affine_shape...)
     ϵ = BN.ϵ
@@ -178,6 +198,9 @@ end
 
 @functor BatchNorm
 
+testmode!(m::BatchNorm, mode = true) =
+  (m.active = (isnothing(mode) || mode == :auto) ? nothing : !mode; m)
+
 function Base.show(io::IO, l::BatchNorm)
   print(io, "BatchNorm($(join(size(l.β), ", "))")
   (l.λ == identity) || print(io, ", λ = $(l.λ)")
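A sketch of how `testmode!` changes `BatchNorm`'s forward pass (shapes chosen arbitrarily):

```julia
using Flux

bn = BatchNorm(2)
x = rand(Float32, 2, 10)

trainmode!(bn)
bn(x)   # normalises with batch statistics and updates bn.μ and bn.σ²
testmode!(bn)
bn(x)   # normalises with the accumulated running statistics instead
```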
@@ -201,6 +224,8 @@ it's the usual channel dimension.)
 shifts them to have a new mean and variance (corresponding to the learnable,
 per-channel `bias` and `scale` parameters).
 
+Use [`testmode!`](@ref) during inference.
+
 See [Instance Normalization: The Missing Ingredient for Fast Stylization](https://arxiv.org/abs/1607.08022).
 
 Example:
@@ -223,12 +248,13 @@ mutable struct InstanceNorm{F,V,W,N}
   σ²::W  # moving std
   ϵ::N
   momentum::N
+  active::Union{Bool, Nothing}
 end
 
 InstanceNorm(chs::Integer, λ = identity;
              initβ = (i) -> zeros(Float32, i), initγ = (i) -> ones(Float32, i), ϵ = 1f-5, momentum = 0.1f0) =
   InstanceNorm(λ, initβ(chs), initγ(chs),
-               zeros(chs), ones(chs), ϵ, momentum)
+               zeros(chs), ones(chs), ϵ, momentum, nothing)
 
 trainable(in::InstanceNorm) = (in.β, in.γ)
 
@@ -245,7 +271,7 @@ function (in::InstanceNorm)(x)
   m = div(prod(size(x)), c*bs)
   γ, β = expand_inst(in.γ, affine_shape), expand_inst(in.β, affine_shape)
 
-  if !istraining()
+  if !_isactive(in)
     μ = expand_inst(in.μ, affine_shape)
     σ² = expand_inst(in.σ², affine_shape)
     ϵ = in.ϵ
@@ -271,6 +297,9 @@ end
 
 @functor InstanceNorm
 
+testmode!(m::InstanceNorm, mode = true) =
+  (m.active = (isnothing(mode) || mode == :auto) ? nothing : !mode; m)
+
 function Base.show(io::IO, l::InstanceNorm)
   print(io, "InstanceNorm($(join(size(l.β), ", "))")
   (l.λ == identity) || print(io, ", λ = $(l.λ)")
@@ -291,6 +320,8 @@ For an array of N dimensions, the (N-1)th index is the channel dimension.
 ``G`` is the number of groups along which the statistics would be computed.
 The number of channels must be an integer multiple of the number of groups.
 
+Use [`testmode!`](@ref) during inference.
+
 Example:
 ```
 m = Chain(Conv((3,3), 1=>32, leakyrelu;pad = 1),
@@ -308,12 +339,13 @@ mutable struct GroupNorm{F,V,W,N,T}
   σ²::W  # moving std
   ϵ::N
   momentum::N
+  active::Union{Bool, Nothing}
 end
 
 GroupNorm(chs::Integer, G::Integer, λ = identity;
           initβ = (i) -> zeros(Float32, i), initγ = (i) -> ones(Float32, i), ϵ = 1f-5, momentum = 0.1f0) =
   GroupNorm(G, λ, initβ(chs), initγ(chs),
-            zeros(G,1), ones(G,1), ϵ, momentum)
+            zeros(G,1), ones(G,1), ϵ, momentum, nothing)
 
 trainable(gn::GroupNorm) = (gn.β, gn.γ)
 
@@ -337,7 +369,7 @@ function(gn::GroupNorm)(x)
   β = reshape(gn.β, affine_shape...)
 
   y = reshape(x,((size(x))[1:end-2]...,channels_per_group,groups,batches))
-  if !istraining()
+  if !_isactive(gn)
     og_shape = size(x)
     μ = reshape(gn.μ, μ_affine_shape...) # Shape : (1,1,...C/G,G,1)
     σ² = reshape(gn.σ², μ_affine_shape...) # Shape : (1,1,...C/G,G,1)
@@ -368,6 +400,9 @@ end
 
 @functor GroupNorm
 
+testmode!(m::GroupNorm, mode = true) =
+  (m.active = (isnothing(mode) || mode == :auto) ? nothing : !mode; m)
+
 function Base.show(io::IO, l::GroupNorm)
   print(io, "GroupNorm($(join(size(l.β), ", "))")
   (l.λ == identity) || print(io, ", λ = $(l.λ)")
@@ -1,30 +1,32 @@
 using Flux, Test, Statistics
 using Zygote: pullback
 
-trainmode(f, x...) = pullback(f, x...)[1]
-trainmode(f) = (x...) -> trainmode(f, x...)
+evalwgrad(f, x...) = pullback(f, x...)[1]
 
 @testset "Dropout" begin
   x = [1.,2.,3.]
   @test x == Dropout(0.1)(x)
-  @test x == trainmode(Dropout(0), x)
-  @test zero(x) == trainmode(Dropout(1), x)
+  @test x == evalwgrad(Dropout(0), x)
+  @test zero(x) == evalwgrad(Dropout(1), x)
 
   x = rand(100)
   m = Dropout(0.9)
-  y = trainmode(m, x)
+  y = evalwgrad(m, x)
   @test count(a->a==0, y) > 50
-  y = m(x)
+  testmode!(m, true)
+  y = evalwgrad(m, x) # should override istraining
   @test count(a->a==0, y) == 0
-  y = trainmode(m, x)
+  testmode!(m, false)
+  y = evalwgrad(m, x)
   @test count(a->a==0, y) > 50
 
   x = rand(Float32, 100)
   m = Chain(Dense(100,100),
             Dropout(0.9))
-  y = trainmode(m, x)
+  y = evalwgrad(m, x)
   @test count(a->a == 0, y) > 50
-  y = m(x)
+  testmode!(m, true)
+  y = evalwgrad(m, x) # should override istraining
  @test count(a->a == 0, y) == 0
 
  x = rand(100, 50)
@@ -49,7 +51,7 @@ end
   # initial m.σ is 1
   # initial m.μ is 0
 
-  y = trainmode(m, x)
+  y = evalwgrad(m, x)
   @test isapprox(y, [-1.22474 0 1.22474; -1.22474 0 1.22474], atol = 1.0e-5)
   # julia> x
   # 2×3 Array{Float64,2}:
@@ -82,19 +84,19 @@ end
     @test isapprox(y, sigmoid.((x .- m.μ) ./ sqrt.(m.σ² .+ m.ϵ)), atol = 1.0e-7)
   end
 
-  let m = trainmode(BatchNorm(2)), x = reshape(Float32.(1:6), 3, 2, 1)
+  let m = trainmode!(BatchNorm(2)), x = reshape(Float32.(1:6), 3, 2, 1)
     y = reshape(permutedims(x, [2, 1, 3]), 2, :)
     y = permutedims(reshape(m(y), 2, 3, 1), [2, 1, 3])
     @test m(x) == y
   end
 
-  let m = trainmode(BatchNorm(2)), x = reshape(Float32.(1:12), 2, 3, 2, 1)
+  let m = trainmode!(BatchNorm(2)), x = reshape(Float32.(1:12), 2, 3, 2, 1)
     y = reshape(permutedims(x, [3, 1, 2, 4]), 2, :)
     y = permutedims(reshape(m(y), 2, 2, 3, 1), [2, 3, 1, 4])
     @test m(x) == y
   end
 
-  let m = trainmode(BatchNorm(2)), x = reshape(Float32.(1:24), 2, 2, 3, 2, 1)
+  let m = trainmode!(BatchNorm(2)), x = reshape(Float32.(1:24), 2, 2, 3, 2, 1)
     y = reshape(permutedims(x, [4, 1, 2, 3, 5]), 2, :)
     y = permutedims(reshape(m(y), 2, 2, 2, 3, 1), [2, 3, 4, 1, 5])
     @test m(x) == y
@@ -117,7 +119,7 @@ end
   x = Float64.(x)
   @test m.β == [0, 0]  # initβ(2)
   @test m.γ == [1, 1]  # initγ(2)
-  y = trainmode(m, x)
+  y = evalwgrad(m, x)
 
   #julia> x
   #[:, :, 1] =
@@ -162,7 +164,7 @@ end
     @test isapprox(y, sigmoid.((x .- expand_inst(m.μ, affine_shape)) ./ sqrt.(expand_inst(m.σ², affine_shape) .+ m.ϵ)), atol = 1.0e-7)
   end
 
-  let m = trainmode(InstanceNorm(2)), sizes = (2, 4, 1, 2, 3),
+  let m = trainmode!(InstanceNorm(2)), sizes = (2, 4, 1, 2, 3),
       x = Float32.(reshape(collect(1:prod(sizes)), sizes))
     y = reshape(permutedims(x, [3, 1, 2, 4, 5]), :, 2, 3)
     y = reshape(m(y), sizes...)
@@ -172,14 +174,14 @@ end
   # check that μ, σ², and the output are the correct size for higher rank tensors
   let m = InstanceNorm(2), sizes = (5, 5, 3, 4, 2, 6),
       x = reshape(Float32.(collect(1:prod(sizes))), sizes)
-    y = trainmode(m, x)
+    y = evalwgrad(m, x)
     @test size(m.μ) == (sizes[end - 1], )
     @test size(m.σ²) == (sizes[end - 1], )
     @test size(y) == sizes
   end
 
   # show that instance norm is equal to batch norm when channel and batch dims are squashed
-  let m_inorm = trainmode(InstanceNorm(2)), m_bnorm = trainmode(BatchNorm(12)), sizes = (5, 5, 3, 4, 2, 6),
+  let m_inorm = trainmode!(InstanceNorm(2)), m_bnorm = trainmode!(BatchNorm(12)), sizes = (5, 5, 3, 4, 2, 6),
       x = reshape(Float32.(collect(1:prod(sizes))), sizes)
     @test m_inorm(x) == reshape(m_bnorm(reshape(x, (sizes[1:end - 2]..., :, 1))), sizes)
   end
@@ -204,7 +206,7 @@ if VERSION >= v"1.1"
   @test m.β == [0, 0, 0, 0]  # initβ(32)
   @test m.γ == [1, 1, 1, 1]  # initγ(32)
 
-  y = trainmode(m, x)
+  y = evalwgrad(m, x)
 
   #julia> x
   #[:, :, 1] =
@@ -263,7 +265,7 @@ if VERSION >= v"1.1"
     @test isapprox(y, out, atol = 1.0e-7)
   end
 
-  let m = trainmode(GroupNorm(2,2)), sizes = (2, 4, 1, 2, 3),
+  let m = trainmode!(GroupNorm(2,2)), sizes = (2, 4, 1, 2, 3),
      x = Float32.(reshape(collect(1:prod(sizes)), sizes))
    y = reshape(permutedims(x, [3, 1, 2, 4, 5]), :, 2, 3)
    y = reshape(m(y), sizes...)
@@ -273,20 +275,20 @@ if VERSION >= v"1.1"
   # check that μ, σ², and the output are the correct size for higher rank tensors
   let m = GroupNorm(4,2), sizes = (5, 5, 3, 4, 4, 6),
       x = Float32.(reshape(collect(1:prod(sizes)), sizes))
-    y = trainmode(m, x)
+    y = evalwgrad(m, x)
     @test size(m.μ) == (m.G,1)
     @test size(m.σ²) == (m.G,1)
     @test size(y) == sizes
   end
 
   # show that group norm is the same as instance norm when the group size is the same as the number of channels
-  let IN = trainmode(InstanceNorm(4)), GN = trainmode(GroupNorm(4,4)), sizes = (2,2,3,4,5),
+  let IN = trainmode!(InstanceNorm(4)), GN = trainmode!(GroupNorm(4,4)), sizes = (2,2,3,4,5),
      x = Float32.(reshape(collect(1:prod(sizes)), sizes))
    @test IN(x) ≈ GN(x)
  end
 
  # show that group norm is the same as batch norm for a group of size 1 and batch of size 1
-  let BN = trainmode(BatchNorm(4)), GN = trainmode(GroupNorm(4,4)), sizes = (2,2,3,4,1),
+  let BN = trainmode!(BatchNorm(4)), GN = trainmode!(GroupNorm(4,4)), sizes = (2,2,3,4,1),
      x = Float32.(reshape(collect(1:prod(sizes)), sizes))
    @test BN(x) ≈ GN(x)
  end