# Flux.jl/test/layers/normalisation.jl

using Flux: testmode!
using Flux.Tracker: data
@testset "Dropout" begin
  # Dropout must be the identity in test mode or with rate 0,
  # and must zero everything with rate 1.
  x = [1.,2.,3.]
  @test x == testmode!(Dropout(0.1))(x)
  @test x == Dropout(0)(x)
  @test zero(x) == Dropout(1)(x)

  # With rate 0.9 most entries should be zeroed while training...
  x = rand(100)
  m = Dropout(0.9)
  y = m(x)
  @test count(a->a==0, y) > 50
  # ...none in test mode...
  testmode!(m)
  y = m(x)
  @test count(a->a==0, y) == 0
  # ...and zeroing resumes once training mode is restored.
  testmode!(m, false)
  y = m(x)
  @test count(a->a==0, y) > 50

  # The same train/test behaviour holds when Dropout sits inside a Chain.
  x = rand(100)
  m = Chain(Dense(100,100),
            Dropout(0.9))
  y = m(x)
  @test count(a->a == 0, y) > 50
  testmode!(m)
  y = m(x)
  @test count(a->a == 0, y) == 0

  # With `dims` restricted, whole slices share one dropout mask, so every
  # row (dims=2) — resp. column (dims=1) — contains the same number of zeros.
  x = rand(100, 50)
  m = Dropout(0.5)
  y = m(x; dims=2)
  c = map(i->count(a->a==0, @view y[i, :]), 1:100)
  @test minimum(c) == maximum(c)
  y = m(x; dims=1)
  c = map(i->count(a->a==0, @view y[:, i]), 1:50)
  @test minimum(c) == maximum(c)
end
2017-10-30 05:24:35 +00:00
@testset "BatchNorm" begin
  # Basic 2-feature case: check initial parameters and the running-stat
  # updates after one forward pass in training mode.
  let m = BatchNorm(2), x = param([1 3 5;
                                   2 4 6])
    @test m.β.data == [0, 0]  # initβ(2)
    @test m.γ.data == [1, 1]  # initγ(2)
    # initial m.σ is 1
    # initial m.μ is 0
    @test m.active

    # @test m(x).data ≈ [-1 -1; 0 0; 1 1]'
    m(x)

    # julia> x
    # 2×3 Array{Float64,2}:
    #  1.0  3.0  5.0
    #  2.0  4.0  6.0
    #
    # μ of batch will be
    #  (1. + 3. + 5.) / 3 = 3
    #  (2. + 4. + 6.) / 3 = 4
    #
    # ∴ update rule with momentum:
    #  .1 * 3 + 0 = .3
    #  .1 * 4 + 0 = .4
    @test m.μ ≈ reshape([0.3, 0.4], 2, 1)

    # julia> .1 .* var(x, dims = 2, corrected=false) .* (3 / 2).+ .9 .* [1., 1.]
    # 2×1 Array{Float64,2}:
    #  1.3
    #  1.3
    @test m.σ² ≈ .1 .* var(x.data, dims = 2, corrected=false) .* (3 / 2).+ .9 .* [1., 1.]

    # In test mode the stored running statistics are used for normalisation.
    testmode!(m)
    @test !m.active

    x = m(x).data
    @test isapprox(x[1], (1 .- 0.3) / sqrt(1.3), atol = 1.0e-5)
  end

  # with activation function: output must equal applying the activation to
  # the normalised input computed from the running statistics.
  let m = BatchNorm(2, sigmoid), x = param([1 3 5;
                                            2 4 6])
    @test m.active
    m(x)
    testmode!(m)
    @test !m.active

    y = m(x).data
    @test isapprox(y, data(sigmoid.((x .- m.μ) ./ sqrt.(m.σ² .+ m.ϵ))), atol = 1.0e-7)
  end

  # Higher-rank inputs: normalising over the channel dimension must agree
  # with permuting channels first, flattening to a matrix, applying the
  # layer there, and permuting back.
  let m = BatchNorm(2), x = param(reshape(1:6, 3, 2, 1))
    y = reshape(permutedims(x, [2, 1, 3]), 2, :)
    y = permutedims(reshape(m(y), 2, 3, 1), [2, 1, 3])
    @test m(x) == y
  end

  let m = BatchNorm(2), x = param(reshape(1:12, 2, 3, 2, 1))
    y = reshape(permutedims(x, [3, 1, 2, 4]), 2, :)
    y = permutedims(reshape(m(y), 2, 2, 3, 1), [2, 3, 1, 4])
    @test m(x) == y
  end

  let m = BatchNorm(2), x = param(reshape(1:24, 2, 2, 3, 2, 1))
    y = reshape(permutedims(x, [4, 1, 2, 3, 5]), 2, :)
    y = permutedims(reshape(m(y), 2, 2, 2, 3, 1), [2, 3, 4, 1, 5])
    @test m(x) == y
  end
end