2017-10-26 10:46:12 +00:00
|
|
|
|
using Flux: testmode!
|
|
|
|
|
|
|
|
|
|
# Dropout layer tests: the pass-through cases (test mode, p = 0), the
# drop-everything case (p = 1), and `testmode!` toggling dropping on and off
# for a bare layer as well as for a layer nested inside a `Chain`.
@testset "Dropout" begin
  x = [1.,2.,3.]
  # The input must come back unchanged in test mode and when p = 0.
  @test x == testmode!(Dropout(0.1))(x)
  @test x == Dropout(0)(x)
  # With p = 1 every element is expected to be zeroed. `zero(x)` is used
  # instead of the array form `zeros(x)`, which was deprecated in Julia 0.7
  # and removed in 1.0; both yield an all-zero array shaped like `x`.
  @test zero(x) == Dropout(1)(x)

  x = rand(100)
  m = Dropout(0.9)
  y = m(x)
  # A freshly constructed layer drops: with p = 0.9 more than half of the
  # 100 outputs are expected to be exactly zero.
  @test count(a->a==0, y) > 50
  testmode!(m)
  y = m(x)
  # In test mode nothing may be dropped.
  @test count(a->a==0, y) == 0
  testmode!(m, false)
  y = m(x)
  # Passing `false` switches dropping back on.
  @test count(a->a==0, y) > 50

  # Same checks with the Dropout layer at the end of a Chain: `testmode!`
  # applied to the Chain must reach the nested layer.
  x = rand(100)
  m = Chain(Dense(100,100),
            Dropout(0.9))
  y = m(x)
  @test count(a->a == 0, y) > 50
  testmode!(m)
  y = m(x)
  @test count(a->a == 0, y) == 0
end
|
2017-10-30 05:24:35 +00:00
|
|
|
|
|
|
|
|
|
# BatchNorm layer tests: parameter initialisation, the running-statistics
# update after one forward pass in active mode, test-mode output (with and
# without an activation function), and agreement between N-d inputs and the
# equivalent flattened 2-d inputs.
@testset "BatchNorm" begin
  let bn = BatchNorm(2), input = param([1 2; 3 4; 5 6]')
    @test bn.β.data == [0, 0]  # initβ(2)
    @test bn.γ.data == [1, 1]  # initγ(2)
    # bn.σ starts at 1 and bn.μ starts at 0
    @test bn.active

    # @test bn(input).data ≈ [-1 -1; 0 0; 1 1]'
    bn(input)

    # julia> input
    #  2×3 Array{Float64,2}:
    #   1.0  3.0  5.0
    #   2.0  4.0  6.0
    #
    # Per-row batch means:
    #   (1. + 3. + 5.) / 3 = 3
    #   (2. + 4. + 6.) / 3 = 4
    #
    # ∴ running mean after one momentum step:
    #   .1 * 3 + 0 = .3
    #   .1 * 4 + 0 = .4
    @test bn.μ ≈ reshape([0.3, 0.4], 2, 1)

    # julia> .1 .* std(input, 2, corrected=false) .* (3 / 2) .+ .9 .* [1., 1.]
    # 2×1 Array{Float64,2}:
    #  1.14495
    #  1.14495
    @test bn.σ ≈ .1 .* std(input.data, 2, corrected=false) .* (3 / 2) .+ .9 .* [1., 1.]

    testmode!(bn)
    @test !bn.active

    # In test mode the output is checked against the running statistics
    # computed above (mean .3 and the 1.14495 value for the first row).
    running_μ = 0.3
    running_σ = 1.1449489742783179
    out = bn(input).data
    @test out[1] ≈ (1 - running_μ) / running_σ
  end

  # with activation function
  let bn = BatchNorm(2, σ), input = param([1 2; 3 4; 5 6]')
    @test bn.active
    bn(input)

    testmode!(bn)
    @test !bn.active

    running_μ = 0.3
    running_σ = 1.1449489742783179
    out = bn(input).data
    @test out[1] ≈ σ((1 - running_μ) / running_σ)
  end

  # 3-d input: moving the size-2 dimension first and flattening to 2-d,
  # applying the layer, then undoing the reshape/permute must match applying
  # the layer to the 3-d array directly.
  let bn = BatchNorm(2), input = param(reshape(1:6, 3, 2, 1))
    flat = reshape(permutedims(input, [2, 1, 3]), 2, :)
    expected = permutedims(reshape(bn(flat), 2, 3, 1), [2, 1, 3])
    @test bn(input) == expected
  end

  # Same comparison for a 4-d input.
  let bn = BatchNorm(2), input = param(reshape(1:12, 2, 3, 2, 1))
    flat = reshape(permutedims(input, [3, 1, 2, 4]), 2, :)
    expected = permutedims(reshape(bn(flat), 2, 2, 3, 1), [2, 3, 1, 4])
    @test bn(input) == expected
  end

  # Same comparison for a 5-d input.
  let bn = BatchNorm(2), input = param(reshape(1:24, 2, 2, 3, 2, 1))
    flat = reshape(permutedims(input, [4, 1, 2, 3, 5]), 2, :)
    expected = permutedims(reshape(bn(flat), 2, 2, 2, 3, 1), [2, 3, 4, 1, 5])
    @test bn(input) == expected
  end
end
|