Merge branch 'master' of https://github.com/FluxML/Flux.jl
commit 1eca23e113
.gitlab-ci.yml
@@ -1,10 +1,37 @@
+before_script:
+  - export CI_DISABLE_CURNN_TEST=true
+
 variables:
   CI_IMAGE_TAG: 'cuda'
 
 include:
-  - 'https://raw.githubusercontent.com/JuliaGPU/gitlab-ci/master/templates/v1/common.yml'
-  - 'https://raw.githubusercontent.com/JuliaGPU/gitlab-ci/master/templates/v1/test_v1.0.yml'
-  - 'https://raw.githubusercontent.com/JuliaGPU/gitlab-ci/master/templates/v1/test_dev.yml'
+  - 'https://raw.githubusercontent.com/JuliaGPU/gitlab-ci/master/templates/v3/common.yml'
 
-test:dev:
-  allow_failure: true
+.flux:
+  extends: .test
+  script:
+    - julia -e 'using InteractiveUtils;
+                versioninfo()'
+    - mkdir $JULIA_DEPOT_PATH # Pkg3.jl#325
+    - julia -e 'using Pkg;
+                Pkg.add("CuArrays");'
+    - julia --project -e 'using Pkg;
+                          Pkg.instantiate();
+                          Pkg.build();
+                          Pkg.test(; coverage=true);'
+
+test:v1.0:
+  extends: .flux
+  variables:
+    CI_VERSION_TAG: 'v1.0'
+  only:
+    - staging
+    - trying
+
+test:v1.1:
+  extends: .flux
+  variables:
+    CI_VERSION_TAG: 'v1.1'
+  only:
+    - staging
+    - trying
LICENSE.md
@@ -1,6 +1,6 @@
 The Flux.jl package is licensed under the MIT "Expat" License:
 
-> Copyright (c) 2016: Mike Innes.
+> Copyright (c) 2016-19: Julia Computing, Inc., Mike Innes and Contributors
 >
 > Permission is hereby granted, free of charge, to any person obtaining
 > a copy of this software and associated documentation files (the
NEWS.md
@@ -12,6 +12,7 @@
 * New [AlphaDropout](https://github.com/FluxML/Flux.jl/pull/656).
 * [Data.Iris](https://github.com/FluxML/Flux.jl/pull/652) makes Fisher's Iris dataset available with `Iris.labels` and `Iris.features`.
 * New [InstanceNorm](https://github.com/FluxML/Flux.jl/pull/634), as popularized by [Instance Normalization: The Missing Ingredient for Fast Stylization](https://arxiv.org/abs/1607.08022).
+* New [GroupNorm](https://github.com/FluxML/Flux.jl/pull/696), as described in [Group Normalization](https://arxiv.org/abs/1803.08494).
 
 AD Changes:
Project.toml
@@ -19,6 +19,11 @@ Requires = "ae029012-a4dd-5104-9daa-d747884805df"
 SHA = "ea8e919c-243c-51af-8825-aaa63cd721ce"
 Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
+StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
-Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
 Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
 ZipFile = "a5390f91-8eb1-5f08-bee0-b1d1ffed6cea"
+
+[extras]
+Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
+
+[targets]
+test = ["Test"]
docs/src/models/layers.md
@@ -73,4 +73,5 @@ BatchNorm
 Dropout
+AlphaDropout
 LayerNorm
 GroupNorm
 ```
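For orientation, here is a hedged sketch of the newly documented layers in a small model; the architecture, sizes, and activations are invented for the example and are not from this commit:

```julia
using Flux

# GroupNorm normalises the channel dimension of a conv feature map;
# AlphaDropout is the self-normalising dropout variant added in #656.
m = Chain(Conv((3, 3), 1=>32, relu),
          GroupNorm(32, 8),                  # 32 channels in 8 groups of 4
          x -> reshape(x, :, size(x, 4)),    # flatten for the dense head
          Dense(26 * 26 * 32, 10),
          AlphaDropout(0.5),
          softmax)

x = randn(28, 28, 1, 16)  # a batch of 16 MNIST-sized images
size(m(x))                # (10, 16)
```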
docs/src/training/optimisers.md
@@ -49,5 +49,12 @@ All optimisers return an object that, when passed to `train!`, will update the p
 Descent
 Momentum
 Nesterov
+RMSProp
 ADAM
+AdaMax
+ADAGrad
+ADADelta
+AMSGrad
+NADAM
+ADAMW
 ```
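As context for this list, a minimal sketch of the `train!` workflow these optimisers plug into, using the Flux API of this era; the model, loss, and data below are invented for the example:

```julia
using Flux

m = Chain(Dense(10, 5, relu), Dense(5, 2), softmax)
loss(x, y) = Flux.crossentropy(m(x), y)
data = [(rand(10), Flux.onehot(1, 1:2))]  # a single toy sample

opt = ADAM(0.001)                       # any optimiser from the list above
Flux.train!(loss, params(m), data, opt) # one pass over `data`, updating params(m)
```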
src/Flux.jl
@@ -6,10 +6,8 @@ using Base: tail
 using MacroTools, Juno, Requires, Reexport, Statistics, Random
 using MacroTools: @forward
 
-export Chain, Dense, Maxout,
-       RNN, LSTM, GRU,
-       Conv, ConvTranspose, MaxPool, MeanPool, DepthwiseConv,
-       Dropout, AlphaDropout, LayerNorm, BatchNorm, InstanceNorm,
+export Chain, Dense, Maxout, RNN, LSTM, GRU, Conv, ConvTranspose, MaxPool, MeanPool,
+       DepthwiseConv, Dropout, AlphaDropout, LayerNorm, BatchNorm, InstanceNorm, GroupNorm,
        params, mapleaves, cpu, gpu, f32, f64
 
 @reexport using NNlib
src/data/iris.jl
@@ -19,16 +19,14 @@ module Iris
 using DelimitedFiles
 using ..Data: deps, download_and_verify
 
-const cache_prefix = ""
-
-# Uncomment if the iris.data file is cached to cache.julialang.org.
-# const cache_prefix = "https://cache.julialang.org/"
+const cache_prefix = "https://cache.julialang.org/"
 
 function load()
   isfile(deps("iris.data")) && return
 
   @info "Downloading iris dataset."
-  download_and_verify("$(cache_prefix)https://archive.ics.uci.edu/ml/machine-learning-databases/iris/iris.data",
+  download_and_verify("$(cache_prefix)http://archive.ics.uci.edu/ml/machine-learning-databases/iris/iris.data",
                       deps("iris.data"),
                       "6f608b71a7317216319b4d27b4d9bc84e6abd734eda7872b71a458569e2656c0")
 end
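A quick sketch of how this module is consumed once `load` has fetched the file, per the NEWS entry above; the exact shapes follow the dataset's 150 four-feature records and are an assumption here:

```julia
using Flux

features = Flux.Data.Iris.features() # 4×150 matrix of sepal/petal measurements
labels   = Flux.Data.Iris.labels()   # 150-element vector of species names
```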
src/layers/basic.jl
@@ -40,7 +40,24 @@ function Base.show(io::IO, c::Chain)
   print(io, ")")
 end
 
-activations(c::Chain, x) = accumulate((x, m) -> m(x), c.layers, init = x)
+# This is a temporary and naive implementation;
+# it might be replaced in the future for better performance.
+# See issue https://github.com/FluxML/Flux.jl/issues/702
+# Johnny Chen -- @johnnychen94
+"""
+    activations(c::Chain, input)
+
+Calculate the forward results of each layer in Chain `c` with `input` as model input.
+"""
+function activations(c::Chain, input)
+  rst = []
+  for l in c
+    x = get(rst, length(rst), input)
+    push!(rst, l(x))
+  end
+  return rst
+end
 
 """
     Dense(in::Integer, out::Integer, σ = identity)

@@ -158,7 +175,7 @@ will construct a `Maxout` layer over 4 internal dense linear layers,
 each identical in structure (784 inputs, 128 outputs).
 ```julia
 insize = 784
-outsie = 128
+outsize = 128
 Maxout(()->Dense(insize, outsize), 4)
 ```
 """
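A small usage sketch of the new helper, mirroring the tests added further down; the model here is invented for the example:

```julia
using Flux
import Flux: activations

m = Chain(Dense(10, 5, σ), Dense(5, 2), softmax)
x = rand(10)

acts = activations(m, x)  # intermediate output of every layer, in order
acts[1] == m[1](x)        # first entry is the first layer's output
length(acts) == 3         # one entry per layer
```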
src/layers/conv.jl
@@ -165,6 +165,12 @@ function Base.show(io::IO, l::DepthwiseConv)
   print(io, ")")
 end
 
+(a::DepthwiseConv{<:Any,<:Any,W})(x::AbstractArray{T}) where {T <: Union{Float32,Float64}, W <: AbstractArray{T}} =
+  invoke(a, Tuple{AbstractArray}, x)
+
+(a::DepthwiseConv{<:Any,<:Any,W})(x::AbstractArray{<:Real}) where {T <: Union{Float32,Float64}, W <: AbstractArray{T}} =
+  a(T.(x))
+
 """
     MaxPool(k)
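These two methods let a `DepthwiseConv` whose weights are `Float32` or `Float64` accept any real-valued input: matching floating-point inputs pass straight through, everything else is converted to the weight eltype first. A hedged sketch of the effect, with invented sizes:

```julia
using Flux

d = DepthwiseConv((2, 2), 3 => 3)
x = rand(1:10, 28, 28, 3, 1)  # an Int-valued input batch
y = d(x)                      # Ints are converted to the weight eltype, then convolved
size(y, 3) == 3
```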
src/layers/normalise.jl
@@ -286,3 +286,109 @@ function Base.show(io::IO, l::InstanceNorm)
   (l.λ == identity) || print(io, ", λ = $(l.λ)")
   print(io, ")")
 end
+
+"""
+    GroupNorm(chs::Integer, G::Integer, λ = identity;
+              initβ = (i) -> zeros(Float32, i), initγ = (i) -> ones(Float32, i),
+              ϵ = 1f-5, momentum = 0.1f0)
+
+Group Normalization. This layer can outperform Batch Normalization and
+Instance Normalization.
+
+`chs` is the number of channels, the channel dimension of your input.
+For an array of N dimensions, the (N-1)th index is the channel dimension.
+
+`G` is the number of groups along which the statistics are computed.
+The number of channels must be an integer multiple of the number of groups.
+
+Example:
+```
+m = Chain(Conv((3,3), 1=>32, leakyrelu; pad = 1),
+          GroupNorm(32, 16)) # 32 channels, 16 groups (G = 16), thus 2 channels per group
+```
+
+Link: https://arxiv.org/abs/1803.08494
+"""
+mutable struct GroupNorm{F,V,W,N,T}
+  G::T  # number of groups
+  λ::F  # activation function
+  β::V  # bias
+  γ::V  # scale
+  μ::W  # moving mean
+  σ²::W # moving variance
+  ϵ::N
+  momentum::N
+  active::Bool
+end
+
+GroupNorm(chs::Integer, G::Integer, λ = identity;
+          initβ = (i) -> zeros(Float32, i), initγ = (i) -> ones(Float32, i), ϵ = 1f-5, momentum = 0.1f0) =
+  GroupNorm(G, λ, param(initβ(chs)), param(initγ(chs)),
+            zeros(G,1), ones(G,1), ϵ, momentum, true)
+
+function (gn::GroupNorm)(x)
+  size(x, ndims(x)-1) == length(gn.β) || error("GroupNorm expected $(length(gn.β)) channels, but got $(size(x, ndims(x)-1)) channels")
+  ndims(x) > 2 || error("GroupNorm requires inputs with at least 3 dimensions")
+  size(x, ndims(x)-1) % gn.G == 0 || error("The number of groups ($(gn.G)) must divide the number of channels ($(size(x, ndims(x)-1)))")
+
+  dims = length(size(x))
+  groups = gn.G
+  channels = size(x, dims-1)
+  batches = size(x, dims)
+  channels_per_group = div(channels, groups)
+  affine_shape = ones(Int, dims)
+
+  # Output reshaped to (W,H...,C/G,G,N)
+  affine_shape[end-1] = channels
+
+  μ_affine_shape = ones(Int, dims + 1)
+  μ_affine_shape[end-1] = groups
+
+  m = prod(size(x)[1:end-2]) * channels_per_group
+  γ = reshape(gn.γ, affine_shape...)
+  β = reshape(gn.β, affine_shape...)
+
+  y = reshape(x, (size(x)[1:end-2]..., channels_per_group, groups, batches))
+  if !gn.active
+    og_shape = size(x)
+    μ = reshape(gn.μ, μ_affine_shape...)   # Shape : (1,1,...C/G,G,1)
+    σ² = reshape(gn.σ², μ_affine_shape...) # Shape : (1,1,...C/G,G,1)
+    ϵ = gn.ϵ
+  else
+    T = eltype(x)
+    og_shape = size(x)
+    axes = [(1:ndims(y)-2)...] # reduce over all but the group and batch dims
+    μ = mean(y, dims = axes)
+    σ² = mean((y .- μ) .^ 2, dims = axes)
+
+    ϵ = data(convert(T, gn.ϵ))
+    # update moving mean/variance
+    mtm = data(convert(T, gn.momentum))
+
+    gn.μ = mean((1 - mtm) .* gn.μ .+ mtm .* reshape(data(μ), (groups, batches)), dims = 2)
+    gn.σ² = mean((1 - mtm) .* gn.σ² .+ (mtm * m / (m - 1)) .* reshape(data(σ²), (groups, batches)), dims = 2)
+  end
+
+  let λ = gn.λ
+    x̂ = (y .- μ) ./ sqrt.(σ² .+ ϵ)
+
+    # Reshape x̂ back to the original input shape
+    x̂ = reshape(x̂, og_shape)
+    λ.(γ .* x̂ .+ β)
+  end
+end
+
+children(gn::GroupNorm) =
+  (gn.λ, gn.β, gn.γ, gn.μ, gn.σ², gn.ϵ, gn.momentum, gn.active)
+
+mapchildren(f, gn::GroupNorm) = # e.g. mapchildren(cu, gn)
+  GroupNorm(gn.G, gn.λ, f(gn.β), f(gn.γ), f(gn.μ), f(gn.σ²), gn.ϵ, gn.momentum, gn.active)
+
+_testmode!(gn::GroupNorm, test) = (gn.active = !test)
+
+function Base.show(io::IO, l::GroupNorm)
+  print(io, "GroupNorm($(join(size(l.β), ", "))")
+  (l.λ == identity) || print(io, ", λ = $(l.λ)")
+  print(io, ")")
+end
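For a sense of the new layer in practice, a hedged usage sketch; the array sizes are invented, using the WHCN layout assumed above:

```julia
using Flux

gn = GroupNorm(32, 8)        # 32 channels split into 8 groups of 4
x = randn(28, 28, 32, 16)    # WHCN input batch; channels sit at dim N-1
y = gn(x)
size(y) == size(x)           # normalisation preserves the input shape
```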
test/cuda/cuda.jl
@@ -41,5 +41,7 @@ end
 if CuArrays.libcudnn != nothing
   @info "Testing Flux/CUDNN"
   include("cudnn.jl")
-  include("curnn.jl")
+  if !haskey(ENV, "CI_DISABLE_CURNN_TEST")
+    include("curnn.jl")
+  end
 end
test/layers/basic.jl
@@ -1,63 +1,75 @@
 using Test, Random
+import Flux: activations
 
 @testset "basic" begin
+  @testset "helpers" begin
+    @testset "activations" begin
+      dummy_model = Chain(Dense(10, 5, σ), Dense(5, 2), softmax)
+      x = rand(10)
+      @test activations(Chain(), x) == []
+      @test activations(dummy_model, x)[1] == dummy_model[1](x)
+      @test activations(dummy_model, x)[2] == x |> dummy_model[1] |> dummy_model[2]
+      @test activations(Chain(identity, x -> :foo), x)[2] == :foo # results include `Any` type
+    end
+  end
+
   @testset "Chain" begin
     @test_nowarn Chain(Dense(10, 5, σ), Dense(5, 2))(randn(10))
     @test_throws DimensionMismatch Chain(Dense(10, 5, σ), Dense(2, 1))(randn(10))
     # numeric test should be put into testset of corresponding layer
   end
 
   @testset "Dense" begin
     @test length(Dense(10, 5)(randn(10))) == 5
     @test_throws DimensionMismatch Dense(10, 5)(randn(1))
     @test_throws MethodError Dense(10, 5)(1) # avoid broadcasting
     @test_throws MethodError Dense(10, 5).(randn(10)) # avoid broadcasting
 
     @test Dense(10, 1, identity, initW = ones, initb = zeros)(ones(10,1)) == 10*ones(1, 1)
     @test Dense(10, 1, identity, initW = ones, initb = zeros)(ones(10,2)) == 10*ones(1, 2)
     @test Dense(10, 2, identity, initW = ones, initb = zeros)(ones(10,1)) == 10*ones(2, 1)
     @test Dense(10, 2, identity, initW = ones, initb = zeros)([ones(10,1) 2*ones(10,1)]) == [10 20; 10 20]
   end
 
   @testset "Diagonal" begin
     @test length(Flux.Diagonal(10)(randn(10))) == 10
     @test length(Flux.Diagonal(10)(1)) == 10
     @test length(Flux.Diagonal(10)(randn(1))) == 10
     @test_throws DimensionMismatch Flux.Diagonal(10)(randn(2))
 
     @test Flux.Diagonal(2)([1 2]) == [1 2; 1 2]
     @test Flux.Diagonal(2)([1,2]) == [1,2]
     @test Flux.Diagonal(2)([1 2; 3 4]) == [1 2; 3 4]
   end
 
   @testset "Maxout" begin
     # Note that the normal common usage of Maxout is as per the docstring
     # These are abnormal constructors used for testing purposes
 
     @testset "Constructor" begin
       mo = Maxout(() -> identity, 4)
       input = rand(40)
       @test mo(input) == input
     end
 
     @testset "simple alternatives" begin
       mo = Maxout((x -> x, x -> 2x, x -> 0.5x))
       input = rand(40)
       @test mo(input) == 2*input
     end
 
     @testset "complex alternatives" begin
       mo = Maxout((x -> [0.5; 0.1]*x, x -> [0.2; 0.7]*x))
       input = [3.0 2.0]
       target = [0.5, 0.7].*input
       @test mo(input) == target
     end
 
     @testset "params" begin
       mo = Maxout(()->Dense(32, 64), 4)
       ps = params(mo)
       @test length(ps) == 8 # 4 alts, each with weight and bias
     end
   end
 end
test/layers/conv.jl
@@ -32,4 +32,14 @@ end
   m2 = DepthwiseConv((2, 2), 3)
 
   @test size(m2(r), 3) == 3
+
+  x = zeros(Float64, 28, 28, 3, 5)
+
+  m3 = DepthwiseConv((2, 2), 3 => 5)
+
+  @test size(m3(x), 3) == 15
+
+  m4 = DepthwiseConv((2, 2), 3)
+
+  @test size(m4(x), 3) == 3
 end
test/layers/normalise.jl
@@ -200,3 +200,114 @@ end
   end
 
 end
+
+@testset "GroupNorm" begin
+  # begin tests
+  squeeze(x) = dropdims(x, dims = tuple(findall(size(x) .== 1)...)) # To remove all singleton dimensions
+
+  let m = GroupNorm(4,2), sizes = (3,4,2),
+      x = param(reshape(collect(1:prod(sizes)), sizes))
+
+    @test m.β.data == [0, 0, 0, 0] # initβ(4)
+    @test m.γ.data == [1, 1, 1, 1] # initγ(4)
+
+    @test m.active
+
+    m(x)
+
+    #julia> x
+    #[:, :, 1] =
+    # 1.0  4.0  7.0  10.0
+    # 2.0  5.0  8.0  11.0
+    # 3.0  6.0  9.0  12.0
+    #
+    #[:, :, 2] =
+    # 13.0  16.0  19.0  22.0
+    # 14.0  17.0  20.0  23.0
+    # 15.0  18.0  21.0  24.0
+    #
+    # μ will be
+    # (1. + 2. + 3. + 4. + 5. + 6.) / 6 = 3.5
+    # (7. + 8. + 9. + 10. + 11. + 12.) / 6 = 9.5
+    #
+    # (13. + 14. + 15. + 16. + 17. + 18.) / 6 = 15.5
+    # (19. + 20. + 21. + 22. + 23. + 24.) / 6 = 21.5
+    #
+    # μ =
+    # 3.5   15.5
+    # 9.5   21.5
+    #
+    # ∴ update rule with momentum:
+    # (1. - .1) * 0 + .1 * (3.5 + 15.5) / 2 = 0.95
+    # (1. - .1) * 0 + .1 * (9.5 + 21.5) / 2 = 1.55
+    @test m.μ ≈ [0.95, 1.55]
+
+    # julia> mean(var(reshape(x,3,2,2,2), dims=(1,2)) .* .1, dims=2) .+ .9*1.
+    # 2-element Array{Tracker.TrackedReal{Float64},1}:
+    #  1.25
+    #  1.25
+    @test m.σ² ≈ mean(squeeze(var(reshape(x,3,2,2,2), dims=(1,2))) .* .1, dims=2) .+ .9*1.
+
+    testmode!(m)
+    @test !m.active
+
+    x′ = m(x).data
+    @test isapprox(x′[1], (1 - 0.95) / sqrt(1.25 + 1f-5), atol = 1.0e-5)
+  end
+  # with activation function
+  let m = GroupNorm(4,2, sigmoid), sizes = (3, 4, 2),
+      x = param(reshape(collect(1:prod(sizes)), sizes))
+
+    μ_affine_shape = ones(Int, length(sizes) + 1)
+    μ_affine_shape[end-1] = 2 # number of groups
+
+    affine_shape = ones(Int, length(sizes) + 1)
+    affine_shape[end-2] = 2 # channels per group
+    affine_shape[end-1] = 2 # number of groups
+    affine_shape[1] = sizes[1]
+    affine_shape[end] = sizes[end]
+
+    og_shape = size(x)
+
+    @test m.active
+    m(x)
+
+    testmode!(m)
+    @test !m.active
+
+    y = m(x)
+    x_ = reshape(x, affine_shape...)
+    out = reshape(data(sigmoid.((x_ .- reshape(m.μ, μ_affine_shape...)) ./ sqrt.(reshape(m.σ², μ_affine_shape...) .+ m.ϵ))), og_shape)
+    @test isapprox(y, out, atol = 1.0e-7)
+  end
+
+  let m = GroupNorm(2,2), sizes = (2, 4, 1, 2, 3),
+      x = param(reshape(collect(1:prod(sizes)), sizes))
+    y = reshape(permutedims(x, [3, 1, 2, 4, 5]), :, 2, 3)
+    y = reshape(m(y), sizes...)
+    @test m(x) == y
+  end
+
+  # check that μ, σ², and the output are the correct size for higher-rank tensors
+  let m = GroupNorm(4,2), sizes = (5, 5, 3, 4, 4, 6),
+      x = param(reshape(collect(1:prod(sizes)), sizes))
+    y = m(x)
+    @test size(m.μ) == (m.G, 1)
+    @test size(m.σ²) == (m.G, 1)
+    @test size(y) == sizes
+  end
+
+  # show that group norm is the same as instance norm when the group size is the same as the number of channels
+  let IN = InstanceNorm(4), GN = GroupNorm(4,4), sizes = (2,2,3,4,5),
+      x = param(reshape(collect(1:prod(sizes)), sizes))
+    @test IN(x) ≈ GN(x)
+  end
+
+  # show that group norm is the same as batch norm for a group of size 1 and batch of size 1
+  let BN = BatchNorm(4), GN = GroupNorm(4,4), sizes = (2,2,3,4,1),
+      x = param(reshape(collect(1:prod(sizes)), sizes))
+    @test BN(x) ≈ GN(x)
+  end
+end
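The moving-mean arithmetic asserted in the first test above can be checked independently of Flux; a standalone sketch reproducing the 0.95/1.55 figures:

```julia
using Statistics

x = reshape(collect(1.0:24.0), 3, 4, 2) # same input as the test above
y = reshape(x, 3, 2, 2, 2)              # (W, C/G, G, N) with G = 2 groups

groupmeans = reshape(mean(y, dims = (1, 2)), 2, 2) # [3.5 15.5; 9.5 21.5]
mtm = 0.1                                          # momentum
moving_mean = (1 - mtm) .* zeros(2) .+ mtm .* vec(mean(groupmeans, dims = 2))
moving_mean ≈ [0.95, 1.55]                         # matches @test m.μ ≈ [0.95, 1.55]
```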