Fix Glorot initialization, add He initialization
Should fix the issue reported at https://github.com/FluxML/Flux.jl/issues/442 . Adds He weight initialization as a bonus :-)
This commit is contained in:
parent
967cc1c175
commit
4530ac65c7
10
src/utils.jl
10
src/utils.jl
|
@ -1,6 +1,12 @@
|
|||
# Arrays
|
||||
glorot_uniform(dims...) = (rand(Float32, dims...) .- 0.5f0) .* sqrt(24.0f0/sum(dims))
|
||||
glorot_normal(dims...) = randn(Float32, dims...) .* sqrt(2.0f0/sum(dims))
|
||||
nfan(n_in, n_out) = n_in, n_out #fan-in, fan-out
|
||||
nfan(dims...) = prod(dims[1:end-2]) .* (dims[end-1], dims[end]) #In case of convolution kernels
|
||||
|
||||
glorot_uniform(dims...) = (rand(Float32, dims...) .- 0.5f0) .* sqrt(24.0f0 / sum(nfan(dims...)))
|
||||
glorot_normal(dims...) = randn(Float32, dims...) .* sqrt(2.0f0 / sum(nfan(dims...)))
|
||||
|
||||
he_uniform(dims...) = (rand(Float32, dims...) .- 0.5f0) .* sqrt(24.0f0 / first(nfan(dims...)))
|
||||
he_normal(dims...) = randn(Float32, dims...) .* sqrt(2.0f0 / first(nfan(dims...)))
|
||||
|
||||
ones(T::Type, dims...) = Base.ones(T, dims...)
|
||||
zeros(T::Type, dims...) = Base.zeros(T, dims...)
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
using Flux
|
||||
using Flux: throttle, glorot_uniform, glorot_normal, stack, unstack
|
||||
using StatsBase: std
|
||||
using Flux: throttle, nfan, glorot_uniform, glorot_normal, he_uniform, he_normal,
|
||||
stack, unstack
|
||||
using StatsBase: var
|
||||
using Random
|
||||
using Test
|
||||
|
||||
|
@ -56,18 +57,36 @@ end
|
|||
# Set random seed so that these tests don't fail randomly
|
||||
Random.seed!(0)
|
||||
|
||||
# glorot_uniform should yield a kernel with stddev ~= sqrt(6/(n_in + n_out)),
|
||||
# and glorot_normal should yield a kernel with stddev != 2/(n_in _ n_out)
|
||||
for (n_in, n_out) in [(100, 100), (100, 400)]
|
||||
v = glorot_uniform(n_in, n_out)
|
||||
@test minimum(v) > -1.1*sqrt(6/(n_in + n_out))
|
||||
@test minimum(v) < -0.9*sqrt(6/(n_in + n_out))
|
||||
@test maximum(v) > 0.9*sqrt(6/(n_in + n_out))
|
||||
@test maximum(v) < 1.1*sqrt(6/(n_in + n_out))
|
||||
@testset "Fan in/out" begin
|
||||
@test nfan(100, 200) == (100, 200) #For Dense layer
|
||||
@test nfan(2, 30, 40) == (2 * 30, 2 * 40) #For 1D Conv layer
|
||||
@test nfan(2, 3, 40, 50) == (2 * 3 * 40, 2 * 3 * 50) #For 2D Conv layer
|
||||
@test nfan(2, 3, 4, 50, 60) == (2 * 3 * 4 * 50, 2 * 3 * 4 * 60) #For 3D Conv layer
|
||||
end
|
||||
|
||||
v = glorot_normal(n_in, n_out)
|
||||
@test std(v) > 0.9*sqrt(2/(n_in + n_out))
|
||||
@test std(v) < 1.1*sqrt(2/(n_in + n_out))
|
||||
@testset "glorot" begin
|
||||
# glorot_uniform and glorot_normal should both yield a kernel with
|
||||
# variance ≈ 2/(fan_in + fan_out)
|
||||
for dims ∈ [(100, 100), (100, 400), (2, 3, 32, 64), (2, 3, 4, 32, 64)]
|
||||
for init ∈ [glorot_uniform, glorot_normal]
|
||||
v = init(dims...)
|
||||
fan_in, fan_out = nfan(dims...)
|
||||
σ2 = 2 / (fan_in + fan_out)
|
||||
@test 0.9σ2 < var(v) < 1.1σ2
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
@testset "he" begin
|
||||
# he_uniform and he_normal should both yield a kernel with variance ≈ 2/fan_in
|
||||
for dims ∈ [(100, 100), (100, 400), (2, 3, 32, 64), (2, 3, 4, 32, 64)]
|
||||
for init ∈ [he_uniform, he_normal]
|
||||
v = init(dims...)
|
||||
fan_in, fan_out = nfan(dims...)
|
||||
σ2 = 2 / fan_in
|
||||
@test 0.9σ2 < var(v) < 1.1σ2
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
|
|
Loading…
Reference in New Issue