diff --git a/src/utils.jl b/src/utils.jl index d3d01a11..b2fe76bf 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -1,12 +1,14 @@ # Arrays -nfan(n_in, n_out) = n_in, n_out #fan-in, fan-out +nfan() = 1, 1 #fan_in, fan_out +nfan(n) = 1, n #A vector is treated as a n×1 matrix +nfan(n_out, n_in) = n_in, n_out #In case of Dense kernels: arranged as matrices nfan(dims...) = prod(dims[1:end-2]) .* (dims[end-1], dims[end]) #In case of convolution kernels glorot_uniform(dims...) = (rand(Float32, dims...) .- 0.5f0) .* sqrt(24.0f0 / sum(nfan(dims...))) glorot_normal(dims...) = randn(Float32, dims...) .* sqrt(2.0f0 / sum(nfan(dims...))) -he_uniform(dims...) = (rand(Float32, dims...) .- 0.5f0) .* sqrt(24.0f0 / first(nfan(dims...))) -he_normal(dims...) = randn(Float32, dims...) .* sqrt(2.0f0 / first(nfan(dims...))) +he_uniform(dims...) = (rand(Float32, dims...) .- 0.5f0) .* sqrt(24.0f0 / last(nfan(dims...))) +he_normal(dims...) = randn(Float32, dims...) .* sqrt(2.0f0 / last(nfan(dims...))) ones(T::Type, dims...) = Base.ones(T, dims...) zeros(T::Type, dims...) = Base.zeros(T, dims...) diff --git a/test/utils.jl b/test/utils.jl index 99492d4e..22b8f26a 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -58,7 +58,9 @@ end Random.seed!(0) @testset "Fan in/out" begin - @test nfan(100, 200) == (100, 200) #For Dense layer + @test nfan() == (1, 1) #For a constant + @test nfan(100) == (1, 100) #For vector + @test nfan(100, 200) == (200, 100) #For Dense layer @test nfan(2, 30, 40) == (2 * 30, 2 * 40) #For 1D Conv layer @test nfan(2, 3, 40, 50) == (2 * 3 * 40, 2 * 3 * 50) #For 2D Conv layer @test nfan(2, 3, 4, 50, 60) == (2 * 3 * 4 * 50, 2 * 3 * 4 * 60) #For 3D Conv layer @@ -67,7 +69,7 @@ end @testset "glorot" begin # glorot_uniform and glorot_normal should both yield a kernel with # variance ≈ 2/(fan_in + fan_out) - for dims ∈ [(100, 100), (100, 400), (2, 3, 32, 64), (2, 3, 4, 32, 64)] + for dims ∈ [(1000,), (100, 100), (100, 400), (2, 3, 32, 64), (2, 3, 4, 32, 64)] for init ∈ [glorot_uniform, glorot_normal] v = init(dims...) fan_in, fan_out = nfan(dims...) @@ -79,11 +81,11 @@ end @testset "he" begin # he_uniform and he_normal should both yield a kernel with variance ≈ 2/fan_in - for dims ∈ [(100, 100), (100, 400), (2, 3, 32, 64), (2, 3, 4, 32, 64)] + for dims ∈ [(1000,), (100, 100), (100, 400), (2, 3, 32, 64), (2, 3, 4, 32, 64)] for init ∈ [he_uniform, he_normal] v = init(dims...) fan_in, fan_out = nfan(dims...) - σ2 = 2 / fan_in + σ2 = 2 / fan_out @test 0.9σ2 < var(v) < 1.1σ2 end end